Esempio n. 1
0
def test_skip_gram_sample_random_skips():
    """Exercise skip_gram_sample with min_skips != max_skips (random output).

    How many pairs come out is not deterministic here, so the global seed is
    pinned and an explicit op-level seed is passed; together they keep the
    expected pairs below stable.
    """
    # The global seed must be set before the sampling op is created.
    tf.random.set_seed(42)

    sentence = tf.constant(b"the quick brown fox jumps over".split())
    tokens, labels = text.skip_gram_sample(
        sentence, min_skips=1, max_skips=2, seed=9
    )

    # One row per center token, in the exact order asserted below.
    pairs = [
        (b"the", b"quick"), (b"the", b"brown"),
        (b"quick", b"the"), (b"quick", b"brown"), (b"quick", b"fox"),
        (b"brown", b"the"), (b"brown", b"quick"),
        (b"brown", b"fox"), (b"brown", b"jumps"),
        (b"fox", b"brown"), (b"fox", b"jumps"),
        (b"jumps", b"fox"), (b"jumps", b"over"),
        (b"over", b"fox"), (b"over", b"jumps"),
    ]
    want_tokens, want_labels = _split_tokens_labels(pairs)
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
Esempio n. 2
0
 def test_skip_gram_sample_emit_self(self):
     """Check that emit_self_as_target=True also pairs each token with itself."""
     words = tf.constant(b"the quick brown fox jumps".split())
     tokens, labels = text.skip_gram_sample(
         words, min_skips=2, max_skips=2, emit_self_as_target=True
     )
     # One row per center token; the self-pair sits at the token's position.
     pairs = [
         (b"the", b"the"), (b"the", b"quick"), (b"the", b"brown"),
         (b"quick", b"the"), (b"quick", b"quick"),
         (b"quick", b"brown"), (b"quick", b"fox"),
         (b"brown", b"the"), (b"brown", b"quick"), (b"brown", b"brown"),
         (b"brown", b"fox"), (b"brown", b"jumps"),
         (b"fox", b"quick"), (b"fox", b"brown"),
         (b"fox", b"fox"), (b"fox", b"jumps"),
         (b"jumps", b"brown"), (b"jumps", b"fox"), (b"jumps", b"jumps"),
     ]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(expected_tokens, tokens)
     self.assertAllEqual(expected_labels, labels)
 def test_skip_gram_sample_skips_2(self):
     """Check a fixed window of two (min_skips = max_skips = 2)."""
     sentence = constant_op.constant(b"the quick brown fox jumps".split())
     tokens, labels = text.skip_gram_sample(
         sentence, min_skips=2, max_skips=2)
     # One row per center token, in the exact order asserted below.
     pairs = [
         (b"the", b"quick"), (b"the", b"brown"),
         (b"quick", b"the"), (b"quick", b"brown"), (b"quick", b"fox"),
         (b"brown", b"the"), (b"brown", b"quick"),
         (b"brown", b"fox"), (b"brown", b"jumps"),
         (b"fox", b"quick"), (b"fox", b"brown"), (b"fox", b"jumps"),
         (b"jumps", b"brown"), (b"jumps", b"fox"),
     ]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(expected_tokens, tokens)
     self.assertAllEqual(expected_labels, labels)
    def test_skip_gram_sample_random_skips(self):
        """Check min_skips != max_skips, where the output is random.

        The number of emitted pairs is not deterministic, so the graph-level
        seed is pinned here and an op-level seed is passed explicitly to keep
        the expected pairs stable.
        """
        # Seed first, before any op is created.
        random_seed.set_random_seed(42)

        sentence = constant_op.constant(
            b"the quick brown fox jumps over".split())
        tokens, labels = text.skip_gram_sample(
            sentence, min_skips=1, max_skips=2, seed=9)
        # One row per center token, in the exact order asserted below.
        pairs = [
            (b"the", b"quick"), (b"the", b"brown"),
            (b"quick", b"the"), (b"quick", b"brown"), (b"quick", b"fox"),
            (b"brown", b"the"), (b"brown", b"quick"),
            (b"brown", b"fox"), (b"brown", b"jumps"),
            (b"fox", b"brown"), (b"fox", b"jumps"),
            (b"jumps", b"fox"), (b"jumps", b"over"),
            (b"over", b"fox"), (b"over", b"jumps"),
        ]
        expected_tokens, expected_labels = self._split_tokens_labels(pairs)
        self.assertAllEqual(expected_tokens, tokens)
        self.assertAllEqual(expected_labels, labels)
Esempio n. 5
0
def test_skip_gram_sample_non_string_input():
    """Check skip_gram_sample on an int16 tensor rather than strings."""
    ids = tf.constant([1, 2, 3], dtype=tf.dtypes.int16)
    tokens, labels = text.skip_gram_sample(ids, min_skips=1, max_skips=1)
    want_tokens, want_labels = _split_tokens_labels(
        [(1, 2), (2, 1), (2, 3), (3, 2)]
    )
    # Tokens are compared first, then labels.
    for want, got in ((want_tokens, tokens), (want_labels, labels)):
        np.testing.assert_equal(np.asanyarray(want), got.numpy())
Esempio n. 6
0
 def test_skip_gram_sample_non_string_input(self):
     """Check that int16 tokens are accepted, not just strings."""
     ids = tf.constant([1, 2, 3], dtype=tf.dtypes.int16)
     tokens, labels = text.skip_gram_sample(ids, min_skips=1, max_skips=1)
     pairs = [(1, 2), (2, 1), (2, 3), (3, 2)]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     # Tokens are compared first, then labels.
     for expected, actual in ((expected_tokens, tokens),
                              (expected_labels, labels)):
         self.assertAllEqual(expected, actual)
Esempio n. 7
0
def test_skip_gram_sample_skips_0():
    """Check a zero-width window (min_skips = max_skips = 0)."""
    words = tf.constant(b"the quick brown".split())

    # emit_self_as_target=False (the default): no pairs are produced.
    empty_tokens, empty_labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=False
    )
    assert len(empty_tokens) == 0
    assert len(empty_labels) == 0

    # emit_self_as_target=True: each token is its own label.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=True
    )
    want_tokens, want_labels = _split_tokens_labels(
        [(word, word) for word in (b"the", b"quick", b"brown")]
    )
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
Esempio n. 8
0
    def test_skip_gram_sample_skips_0(self):
        """Check a zero-width window (min_skips = max_skips = 0)."""
        words = tf.constant(b"the quick brown".split())

        # emit_self_as_target=False (the default): nothing is emitted.
        no_tokens, no_labels = text.skip_gram_sample(
            words, min_skips=0, max_skips=0, emit_self_as_target=False)
        with self.cached_session():
            self.assertEqual(0, len(no_tokens))
            self.assertEqual(0, len(no_labels))

        # emit_self_as_target=True: each token is paired only with itself.
        tokens, labels = text.skip_gram_sample(
            words, min_skips=0, max_skips=0, emit_self_as_target=True)
        expected_tokens, expected_labels = self._split_tokens_labels(
            [(word, word) for word in (b"the", b"quick", b"brown")])
        self.assertAllEqual(expected_tokens, tokens)
        self.assertAllEqual(expected_labels, labels)
Esempio n. 9
0
 def test_skip_gram_sample_limit_exceeds(self):
     """Check skip-gram sampling when `limit` runs past the end of the input."""
     words = tf.constant(b"foo the quick brown".split())
     tokens, labels = text.skip_gram_sample(
         words, min_skips=1, max_skips=1, start=1, limit=100)
     # b"foo" lies before `start`, so it does not appear in the expected pairs.
     pairs = [
         (b"the", b"quick"),
         (b"quick", b"the"), (b"quick", b"brown"),
         (b"brown", b"quick"),
     ]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(expected_tokens, tokens)
     self.assertAllEqual(expected_labels, labels)
Esempio n. 10
0
 def test_skip_gram_sample_start_limit(self):
     """Check skip-gram sampling restricted to a slice of the input."""
     words = tf.constant(b"foo the quick brown bar".split())
     tokens, labels = text.skip_gram_sample(
         words, min_skips=1, max_skips=1, start=1, limit=3)
     # Only the three tokens from index 1 take part; b"foo" and b"bar"
     # are absent from the expected pairs.
     pairs = [
         (b"the", b"quick"),
         (b"quick", b"the"), (b"quick", b"brown"),
         (b"brown", b"quick"),
     ]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(expected_tokens, tokens)
     self.assertAllEqual(expected_labels, labels)
Esempio n. 11
0
 def test_skip_gram_sample_skips_exceed_length(self):
     """Check min/max_skips far larger than the input length."""
     words = tf.constant(b"the quick brown".split())
     tokens, labels = text.skip_gram_sample(
         words, min_skips=100, max_skips=100)
     # With a window wider than the input, every token pairs with all others.
     pairs = [
         (b"the", b"quick"), (b"the", b"brown"),
         (b"quick", b"the"), (b"quick", b"brown"),
         (b"brown", b"the"), (b"brown", b"quick"),
     ]
     expected_tokens, expected_labels = self._split_tokens_labels(pairs)
     self.assertAllEqual(expected_tokens, tokens)
     self.assertAllEqual(expected_labels, labels)
Esempio n. 12
0
def test_skip_gram_sample_limit_exceeds():
    """Check skip-gram sampling when `limit` runs past the end of the input."""
    words = tf.constant(b"foo the quick brown".split())
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=1, start=1, limit=100
    )
    # b"foo" lies before `start`, so it does not appear in the expected pairs.
    pairs = [
        (b"the", b"quick"),
        (b"quick", b"the"), (b"quick", b"brown"),
        (b"brown", b"quick"),
    ]
    want_tokens, want_labels = _split_tokens_labels(pairs)
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
Esempio n. 13
0
def test_skip_gram_sample_skips_2():
    """Check a fixed window of two (min_skips = max_skips = 2)."""
    sentence = tf.constant(b"the quick brown fox jumps".split())
    tokens, labels = text.skip_gram_sample(sentence, min_skips=2, max_skips=2)
    # One row per center token, in the exact order asserted below.
    pairs = [
        (b"the", b"quick"), (b"the", b"brown"),
        (b"quick", b"the"), (b"quick", b"brown"), (b"quick", b"fox"),
        (b"brown", b"the"), (b"brown", b"quick"),
        (b"brown", b"fox"), (b"brown", b"jumps"),
        (b"fox", b"quick"), (b"fox", b"brown"), (b"fox", b"jumps"),
        (b"jumps", b"brown"), (b"jumps", b"fox"),
    ]
    want_tokens, want_labels = _split_tokens_labels(pairs)
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
Esempio n. 14
0
    def test_skip_gram_sample_errors(self):
        """Tests various errors raised by skip_gram_sample().

        Covers negative or inverted skip bounds, non-rank-1 eager input, and
        the vocab_freq_table / vocab_subsampling / corpus_size argument
        combinations that must be rejected with ValueError.
        """
        input_tensor = tf.constant([b"the", b"quick", b"brown"])

        invalid_skips = (
            # min_skips and max_skips must be >= 0.
            (-1, 2),
            (1, -2),
            # min_skips must be <= max_skips.
            (2, 1))
        for min_skips, max_skips in invalid_skips:
            # subTest makes a failure report which (min, max) pair broke.
            with self.subTest(min_skips=min_skips, max_skips=max_skips):
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    text.skip_gram_sample(
                        input_tensor, min_skips=min_skips, max_skips=max_skips)

        # Eager tensor must be rank 1. The rank-2 tensor is built outside the
        # assertRaises context so that only skip_gram_sample() itself can
        # satisfy the expected-exception check.
        invalid_tensor = tf.constant([[b"the"], [b"quick"], [b"brown"]])
        with self.assertRaises(tf.errors.InvalidArgumentError):
            text.skip_gram_sample(invalid_tensor)

        # vocab_freq_table must be provided if vocab_min_count,
        # vocab_subsampling, or corpus_size is specified.
        dummy_input = tf.constant([""])
        kwargs_needing_table = (
            dict(vocab_min_count=1),
            dict(vocab_subsampling=1e-5),
            dict(corpus_size=100),
            dict(vocab_subsampling=1e-5, corpus_size=100),
        )
        for kwargs in kwargs_needing_table:
            with self.subTest(**kwargs):
                with self.assertRaises(ValueError):
                    text.skip_gram_sample(
                        dummy_input, vocab_freq_table=None, **kwargs)

        # vocab_subsampling and corpus_size must both be present or absent.
        dummy_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
        half_specified = (
            dict(vocab_subsampling=None, corpus_size=100),
            dict(vocab_subsampling=1e-5, corpus_size=None),
        )
        for kwargs in half_specified:
            with self.subTest(**kwargs):
                with self.assertRaises(ValueError):
                    text.skip_gram_sample(
                        dummy_input, vocab_freq_table=dummy_table, **kwargs)
Esempio n. 15
0
 def test_skip_gram_sample_errors_v1(self):
     """Tests that skip_gram_sample() rejects input that is not rank 1."""
     # input_tensor must be of rank 1. The rank-2 tensor is built outside the
     # assertRaises context so that only skip_gram_sample() itself can
     # satisfy the expected-exception check.
     invalid_tensor = tf.constant([[b"the"], [b"quick"], [b"brown"]])
     with self.assertRaises(ValueError):
         text.skip_gram_sample(invalid_tensor)
Esempio n. 16
0
def test_skip_gram_sample_errors_v1():
    """Tests that skip_gram_sample() rejects input that is not rank 1."""
    # input_tensor must be of rank 1. The rank-2 tensor is built outside
    # pytest.raises so that only skip_gram_sample() itself can satisfy the
    # expected-exception check.
    invalid_tensor = tf.constant([[b"the"], [b"quick"], [b"brown"]])
    with pytest.raises(tf.errors.InvalidArgumentError):
        text.skip_gram_sample(invalid_tensor)