Code example #1
  def test_skip_gram_sample_random_skips_default_seed(self):
    """Tests outputs are still random when no op-level seed is specified."""
    # Tests normally fix a graph-level seed by default. Clear it here so that
    # neither a graph-level nor an op-level seed is set, mimicking a user who
    # supplies no seed at all. random_seed.get_seed() then yields None for
    # both seeds, forcing the C++ kernel down its default-seed code path.
    random_seed.set_random_seed(None)

    # The input holds 10 words and each skip count is drawn from [1, 5], so
    # two independent samplings produce identical outputs with probability
    # 1/5^10 ~ 1e-7 (the odds of this test flaking).
    words = constant_op.constant([str(i) for i in range(10)])

    # Deliberately omit the op-level seed on both calls.
    first_tokens, first_labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=5)
    second_tokens, second_labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=5)

    with self.cached_session() as sess:
      outputs = sess.run(
          [first_tokens, first_labels, second_tokens, second_labels])
    tokens_a, labels_a, tokens_b, labels_b = outputs

    # Differing lengths already prove the outputs diverged; only compare
    # contents when the lengths happen to match.
    if len(tokens_a) == len(tokens_b):
      self.assertNotEqual(tokens_a.tolist(), tokens_b.tolist())
    if len(labels_a) == len(labels_b):
      self.assertNotEqual(labels_a.tolist(), labels_b.tolist())
Code example #2
 def test_skip_gram_sample_emit_self(self):
   """Tests skip-gram with emit_self_as_target = True."""
   words = constant_op.constant(
       [b"the", b"quick", b"brown", b"fox", b"jumps"])
   tokens, labels = text.skip_gram_sample(
       words, min_skips=2, max_skips=2, emit_self_as_target=True)
   # With a fixed skip window of 2 and self-emission on, every word pairs
   # with each neighbor within 2 positions and with itself.
   pairs = [
       (b"the", b"the"),
       (b"the", b"quick"),
       (b"the", b"brown"),
       (b"quick", b"the"),
       (b"quick", b"quick"),
       (b"quick", b"brown"),
       (b"quick", b"fox"),
       (b"brown", b"the"),
       (b"brown", b"quick"),
       (b"brown", b"brown"),
       (b"brown", b"fox"),
       (b"brown", b"jumps"),
       (b"fox", b"quick"),
       (b"fox", b"brown"),
       (b"fox", b"fox"),
       (b"fox", b"jumps"),
       (b"jumps", b"brown"),
       (b"jumps", b"fox"),
       (b"jumps", b"jumps"),
   ]
   expected_tokens, expected_labels = self._split_tokens_labels(pairs)
   with self.cached_session() as sess:
     tokens_out, labels_out = sess.run([tokens, labels])
     self.assertAllEqual(expected_tokens, tokens_out)
     self.assertAllEqual(expected_labels, labels_out)
Code example #3
  def test_skip_gram_sample_batch(self):
    """Tests skip-gram with batching."""
    words = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=1, batch_size=3)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"fox", b"brown"),
    ])
    with self.cached_session() as sess:
      # Batching routes output through a queue, so queue-runner threads must
      # be started before fetching.
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)

      # Six pairs at batch_size=3 should arrive in exactly two batches,
      # preserving order.
      for begin in (0, 3):
        tokens_out, labels_out = sess.run([tokens, labels])
        self.assertAllEqual(expected_tokens[begin:begin + 3], tokens_out)
        self.assertAllEqual(expected_labels[begin:begin + 3], labels_out)

      coord.request_stop()
      coord.join(threads)
Code example #4
  def test_skip_gram_sample_random_skips(self):
    """Tests skip-gram with min_skips != max_skips, with random output."""
    # With min_skips != max_skips the number of emitted pairs is random, so
    # pin both the graph-level and op-level seeds to keep this test's
    # expected output stable.
    random_seed.set_random_seed(42)

    words = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=2, seed=9)
    # Expected pairs for the specific skip counts drawn under seeds (42, 9).
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"fox"),
        (b"jumps", b"over"),
        (b"over", b"fox"),
        (b"over", b"jumps"),
    ])
    with self.cached_session() as sess:
      tokens_out, labels_out = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens, tokens_out)
      self.assertAllEqual(expected_labels, labels_out)
Code example #5
  def test_skip_gram_sample_skips_0(self):
    """Tests skip-gram with min_skips = max_skips = 0."""
    words = constant_op.constant([b"the", b"quick", b"brown"])

    # With emit_self_as_target left False (the default), a zero skip window
    # yields no pairs at all.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=False)
    with self.cached_session() as sess:
      tokens_out, labels_out = sess.run([tokens, labels])
      self.assertEqual(0, tokens_out.size)
      self.assertEqual(0, labels_out.size)

    # With emit_self_as_target=True, each token pairs only with itself.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=True)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"the"),
        (b"quick", b"quick"),
        (b"brown", b"brown"),
    ])
    with self.cached_session() as sess:
      tokens_out, labels_out = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens, tokens_out)
      self.assertAllEqual(expected_labels, labels_out)
Code example #6
 def test_skip_gram_sample_non_string_input(self):
   """Tests skip-gram with non-string input."""
   # Integer input exercises the op's handling of a non-string dtype.
   words = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
   tokens, labels = text.skip_gram_sample(
       words, min_skips=1, max_skips=1)
   # A fixed skip window of 1 pairs each element with its immediate
   # neighbors, in order.
   expected_tokens, expected_labels = self._split_tokens_labels([
       (1, 2),
       (2, 1),
       (2, 3),
       (3, 2),
   ])
   with self.cached_session() as sess:
     tokens_out, labels_out = sess.run([tokens, labels])
     self.assertAllEqual(expected_tokens, tokens_out)
     self.assertAllEqual(expected_labels, labels_out)
Code example #7
 def test_skip_gram_sample_limit_exceeds(self):
   """Tests skip-gram when limit exceeds the length of the input."""
   words = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
   # start=1 skips b"foo"; limit=100 overshoots the input length and should
   # be clamped so the remaining three words are all consumed.
   tokens, labels = text.skip_gram_sample(
       words, min_skips=1, max_skips=1, start=1, limit=100)
   expected_tokens, expected_labels = self._split_tokens_labels([
       (b"the", b"quick"),
       (b"quick", b"the"),
       (b"quick", b"brown"),
       (b"brown", b"quick"),
   ])
   with self.cached_session() as sess:
     tokens_out, labels_out = sess.run([tokens, labels])
     self.assertAllEqual(expected_tokens, tokens_out)
     self.assertAllEqual(expected_labels, labels_out)
Code example #8
  def test_skip_gram_sample_errors(self):
    """Tests various errors raised by skip_gram_sample()."""
    words = constant_op.constant([b"the", b"quick", b"brown"])

    # min_skips and max_skips must both be >= 0 with min_skips <= max_skips;
    # violations surface only at run time as InvalidArgumentError.
    for bad_min, bad_max in ((-1, 2), (1, -2), (2, 1)):
      tokens, labels = text.skip_gram_sample(
          words, min_skips=bad_min, max_skips=bad_max)
      with self.cached_session() as sess, self.assertRaises(
          errors.InvalidArgumentError):
        sess.run([tokens, labels])

    # A rank-2 input is rejected eagerly at graph-construction time.
    with self.assertRaises(ValueError):
      rank2_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
      text.skip_gram_sample(rank2_tensor)

    # Specifying vocab_min_count, vocab_subsampling, or corpus_size without a
    # vocab_freq_table is a ValueError.
    dummy_input = constant_op.constant([""])
    for kwargs in (
        dict(vocab_min_count=1),
        dict(vocab_subsampling=1e-5),
        dict(corpus_size=100),
        dict(vocab_subsampling=1e-5, corpus_size=100),
    ):
      with self.assertRaises(ValueError):
        text.skip_gram_sample(dummy_input, vocab_freq_table=None, **kwargs)

    # vocab_subsampling and corpus_size must be supplied together or not at
    # all, even when a vocab_freq_table is present.
    dummy_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=None,
          corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=1e-5,
          corpus_size=None)