Example #1
0
    def test_skip_gram_sample_skips_0(self):
        """Tests skip-gram sampling when min_skips == max_skips == 0."""
        words = tf.constant([b"the", b"quick", b"brown"])

        # With emit_self_as_target left False (the default), a skip window
        # of 0 can produce no pairs at all, so both outputs are empty.
        out_tokens, out_labels = text.skip_gram_sample(
            words, min_skips=0, max_skips=0, emit_self_as_target=False)
        with self.cached_session():
            self.assertEqual(0, len(out_tokens))
            self.assertEqual(0, len(out_labels))

        # With emit_self_as_target=True, each token becomes its own label.
        out_tokens, out_labels = text.skip_gram_sample(
            words, min_skips=0, max_skips=0, emit_self_as_target=True)
        want_tokens, want_labels = self._split_tokens_labels(
            [(word, word) for word in (b"the", b"quick", b"brown")])
        self.assertAllEqual(want_tokens, out_tokens)
        self.assertAllEqual(want_labels, out_labels)
Example #2
0
 def test_skip_gram_sample_errors_v1(self):
     """Verifies skip_gram_sample() rejects a rank-2 input tensor."""
     # Only rank-1 input tensors are accepted.
     with self.assertRaises(ValueError):
         rank2_input = tf.constant([[b"the"], [b"quick"], [b"brown"]])
         text.skip_gram_sample(rank2_input)
Example #3
0
    def test_skip_gram_sample_random_skips_default_seed(self):
        """Verifies outputs stay random when no op-level seed is given."""

        # Tests normally install a graph-level seed by default. Clearing it
        # here means neither a graph-level nor an op-level seed is set, so
        # random_seed.get_seed() returns None for both, forcing the C++
        # kernel through its default seed logic — exactly what happens in
        # production when the user supplies no seed.
        random_seed.set_random_seed(None)

        # 10 input words with skips drawn from [1, 5]: the probability that
        # two independent samplings coincide is 1/5^10 ~ 1e-7, which is the
        # probability of this test being flaky.
        words = tf.constant([str(i) for i in range(10)])

        # Intentionally no op-level seed on either call.
        first_tokens, first_labels = text.skip_gram_sample(
            words, min_skips=1, max_skips=5)
        second_tokens, second_labels = text.skip_gram_sample(
            words, min_skips=1, max_skips=5)

        if len(first_tokens) == len(second_tokens):
            self.assertNotEqual(list(first_tokens), list(second_tokens))
        if len(first_labels) == len(second_labels):
            self.assertNotEqual(list(first_labels), list(second_labels))
Example #4
0
 def test_skip_gram_sample_emit_self(self):
     """Checks skip-gram output when emit_self_as_target is enabled."""
     words = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])
     out_tokens, out_labels = text.skip_gram_sample(
         words, min_skips=2, max_skips=2, emit_self_as_target=True)
     # Each token pairs with every neighbor within 2 positions, plus
     # itself (because emit_self_as_target=True).
     pairs = [
         (b"the", b"the"),
         (b"the", b"quick"),
         (b"the", b"brown"),
         (b"quick", b"the"),
         (b"quick", b"quick"),
         (b"quick", b"brown"),
         (b"quick", b"fox"),
         (b"brown", b"the"),
         (b"brown", b"quick"),
         (b"brown", b"brown"),
         (b"brown", b"fox"),
         (b"brown", b"jumps"),
         (b"fox", b"quick"),
         (b"fox", b"brown"),
         (b"fox", b"fox"),
         (b"fox", b"jumps"),
         (b"jumps", b"brown"),
         (b"jumps", b"fox"),
         (b"jumps", b"jumps"),
     ]
     want_tokens, want_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(want_tokens, out_tokens)
     self.assertAllEqual(want_labels, out_labels)
Example #5
0
    def test_skip_gram_sample_random_skips(self):
        """Tests skip-gram with min_skips != max_skips and fixed seeds."""
        # The number of emitted pairs is non-deterministic here, so pin both
        # the graph-level and op-level seeds to keep the expected output
        # stable for this test case.
        random_seed.set_random_seed(42)

        words = tf.constant(
            [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
        out_tokens, out_labels = text.skip_gram_sample(
            words, min_skips=1, max_skips=2, seed=9)
        pairs = [
            (b"the", b"quick"),
            (b"the", b"brown"),
            (b"quick", b"the"),
            (b"quick", b"brown"),
            (b"quick", b"fox"),
            (b"brown", b"the"),
            (b"brown", b"quick"),
            (b"brown", b"fox"),
            (b"brown", b"jumps"),
            (b"fox", b"brown"),
            (b"fox", b"jumps"),
            (b"jumps", b"fox"),
            (b"jumps", b"over"),
            (b"over", b"fox"),
            (b"over", b"jumps"),
        ]
        want_tokens, want_labels = self._split_tokens_labels(pairs)
        self.assertAllEqual(want_tokens, out_tokens)
        self.assertAllEqual(want_labels, out_labels)
Example #6
0
 def test_skip_gram_sample_non_string_input(self):
     """Ensures skip-gram also handles integer (non-string) tensors."""
     values = tf.constant([1, 2, 3], dtype=tf.dtypes.int16)
     out_tokens, out_labels = text.skip_gram_sample(
         values, min_skips=1, max_skips=1)
     want_tokens, want_labels = self._split_tokens_labels(
         [(1, 2), (2, 1), (2, 3), (3, 2)])
     self.assertAllEqual(want_tokens, out_tokens)
     self.assertAllEqual(want_labels, out_labels)
Example #7
0
 def test_skip_gram_sample_limit_exceeds(self):
     """Checks behavior when limit exceeds the length of the input."""
     words = tf.constant([b"foo", b"the", b"quick", b"brown"])
     # limit=100 overshoots the remaining input; sampling should simply
     # run to the end of the tensor, starting after b"foo" (start=1).
     out_tokens, out_labels = text.skip_gram_sample(
         words, min_skips=1, max_skips=1, start=1, limit=100)
     want_tokens, want_labels = self._split_tokens_labels([
         (b"the", b"quick"),
         (b"quick", b"the"),
         (b"quick", b"brown"),
         (b"brown", b"quick"),
     ])
     self.assertAllEqual(want_tokens, out_tokens)
     self.assertAllEqual(want_labels, out_labels)
Example #8
0
 def test_skip_gram_sample_start_limit(self):
     """Tests skip-gram restricted to a sub-span of the input."""
     words = tf.constant([b"foo", b"the", b"quick", b"brown", b"bar"])
     # start=1, limit=3 restricts sampling to the middle three words,
     # excluding both b"foo" and b"bar".
     out_tokens, out_labels = text.skip_gram_sample(
         words, min_skips=1, max_skips=1, start=1, limit=3)
     want_tokens, want_labels = self._split_tokens_labels([
         (b"the", b"quick"),
         (b"quick", b"the"),
         (b"quick", b"brown"),
         (b"brown", b"quick"),
     ])
     self.assertAllEqual(want_tokens, out_tokens)
     self.assertAllEqual(want_labels, out_labels)
Example #9
0
 def test_skip_gram_sample_skips_exceed_length(self):
     """Tests skip-gram when min/max_skips exceed the input length."""
     words = tf.constant([b"the", b"quick", b"brown"])
     # Skips larger than the input are clipped: every token simply pairs
     # with every other token.
     out_tokens, out_labels = text.skip_gram_sample(
         words, min_skips=100, max_skips=100)
     want_tokens, want_labels = self._split_tokens_labels([
         (b"the", b"quick"),
         (b"the", b"brown"),
         (b"quick", b"the"),
         (b"quick", b"brown"),
         (b"brown", b"the"),
         (b"brown", b"quick"),
     ])
     self.assertAllEqual(want_tokens, out_tokens)
     self.assertAllEqual(want_labels, out_labels)
Example #10
0
    def test_skip_gram_sample_errors(self):
        """Tests various errors raised by skip_gram_sample()."""
        words = tf.constant([b"the", b"quick", b"brown"])

        # min_skips and max_skips must each be >= 0, and min <= max.
        for bad_min, bad_max in ((-1, 2), (1, -2), (2, 1)):
            with self.assertRaises(tf.errors.InvalidArgumentError):
                text.skip_gram_sample(words,
                                      min_skips=bad_min,
                                      max_skips=bad_max)

        # An eager input tensor must be rank 1.
        with self.assertRaises(tf.errors.InvalidArgumentError):
            rank2_input = tf.constant([[b"the"], [b"quick"], [b"brown"]])
            text.skip_gram_sample(rank2_input)

        # vocab_min_count, vocab_subsampling, and corpus_size all require
        # vocab_freq_table to be supplied.
        dummy_input = tf.constant([""])
        for kwargs in ({"vocab_min_count": 1},
                       {"vocab_subsampling": 1e-5},
                       {"corpus_size": 100},
                       {"vocab_subsampling": 1e-5, "corpus_size": 100}):
            with self.assertRaises(ValueError):
                text.skip_gram_sample(dummy_input,
                                      vocab_freq_table=None,
                                      **kwargs)

        # vocab_subsampling and corpus_size must both be present or absent.
        dummy_table = lookup_ops.HashTable(
            lookup_ops.KeyValueTensorInitializer([b"foo"], [10]), -1)
        for subsampling, size in ((None, 100), (1e-5, None)):
            with self.assertRaises(ValueError):
                text.skip_gram_sample(dummy_input,
                                      vocab_freq_table=dummy_table,
                                      vocab_subsampling=subsampling,
                                      corpus_size=size)