def test_skip_gram_sample_random_skips():
    """Skip-gram sampling with min_skips != max_skips (randomized window size).

    The pair count is non-deterministic in this mode, so both the global
    graph seed and the op-level seed are pinned to keep the expected
    output stable across runs.
    """
    tf.random.set_seed(42)
    words = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])

    tokens, labels = text.skip_gram_sample(words, min_skips=1, max_skips=2, seed=9)

    # Pairs observed under the fixed seeds above.
    want_pairs = [
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"fox"),
        (b"jumps", b"over"),
        (b"over", b"fox"),
        (b"over", b"jumps"),
    ]
    want_tokens, want_labels = _split_tokens_labels(want_pairs)

    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
def test_skip_gram_sample_emit_self(self):
    """Skip-gram with emit_self_as_target=True: each token also pairs with itself."""
    words = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])

    tokens, labels = text.skip_gram_sample(
        words, min_skips=2, max_skips=2, emit_self_as_target=True
    )

    # Window of 2 on each side, plus the (token, token) self-pair.
    want_pairs = [
        (b"the", b"the"),
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"quick"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"brown"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"fox"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
        (b"jumps", b"jumps"),
    ]
    want_tokens, want_labels = self._split_tokens_labels(want_pairs)

    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_skips_2(self):
    """Tests skip-gram with min_skips = max_skips = 2.

    With a fixed window of 2 on each side, the output is deterministic:
    every token pairs with up to two neighbors left and right.
    """
    # Use the public `tf.constant` for consistency with the other tests in
    # this file (was the legacy private module `constant_op.constant`).
    input_tensor = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])
    tokens, labels = text.skip_gram_sample(input_tensor, min_skips=2, max_skips=2)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
    ])
    self.assertAllEqual(expected_tokens, tokens)
    self.assertAllEqual(expected_labels, labels)
def test_skip_gram_sample_random_skips(self):
    """Tests skip-gram with min_skips != max_skips, with random output.

    The number of outputs is non-deterministic in this case, so the
    global and op-level seeds are fixed to keep the expected output
    constant for this test case.
    """
    # Use the public TF2 APIs (`tf.random.set_seed`, `tf.constant`) for
    # consistency with the rest of the file; the module-level variant of
    # this test uses the same seeds and expects the same output.
    tf.random.set_seed(42)
    input_tensor = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=2, seed=9
    )
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"fox"),
        (b"jumps", b"over"),
        (b"over", b"fox"),
        (b"over", b"jumps"),
    ])
    self.assertAllEqual(expected_tokens, tokens)
    self.assertAllEqual(expected_labels, labels)
def test_skip_gram_sample_non_string_input():
    """Skip-gram accepts non-string (integer) input tensors."""
    ints = tf.constant([1, 2, 3], dtype=tf.dtypes.int16)

    tokens, labels = text.skip_gram_sample(ints, min_skips=1, max_skips=1)

    want_tokens, want_labels = _split_tokens_labels([(1, 2), (2, 1), (2, 3), (3, 2)])
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
def test_skip_gram_sample_non_string_input(self):
    """Skip-gram accepts non-string (integer) input tensors."""
    ints = tf.constant([1, 2, 3], dtype=tf.dtypes.int16)

    tokens, labels = text.skip_gram_sample(ints, min_skips=1, max_skips=1)

    # Window of 1: each element pairs with its immediate neighbors.
    want_pairs = [(1, 2), (2, 1), (2, 3), (3, 2)]
    want_tokens, want_labels = self._split_tokens_labels(want_pairs)

    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_skips_0():
    """Skip-gram with min_skips = max_skips = 0 (empty window)."""
    words = tf.constant([b"the", b"quick", b"brown"])

    # With no window and emit_self_as_target=False, nothing is emitted.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=False
    )
    assert len(tokens) == 0
    assert len(labels) == 0

    # With emit_self_as_target=True, each token becomes its own label.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=True
    )
    want_tokens, want_labels = _split_tokens_labels(
        [(b"the", b"the"), (b"quick", b"quick"), (b"brown", b"brown")]
    )
    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
def test_skip_gram_sample_skips_0(self):
    """Skip-gram with min_skips = max_skips = 0 (empty window)."""
    words = tf.constant([b"the", b"quick", b"brown"])

    # With no window and emit_self_as_target=False, nothing is emitted.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=False
    )
    with self.cached_session():
        self.assertEqual(0, len(tokens))
        self.assertEqual(0, len(labels))

    # With emit_self_as_target=True, each token becomes its own label.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=0, max_skips=0, emit_self_as_target=True
    )
    want_tokens, want_labels = self._split_tokens_labels(
        [(b"the", b"the"), (b"quick", b"quick"), (b"brown", b"brown")]
    )
    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_limit_exceeds(self):
    """A `limit` larger than the remaining input is clipped to the input end."""
    words = tf.constant([b"foo", b"the", b"quick", b"brown"])

    # start=1 skips b"foo"; limit=100 covers everything that's left.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=1, start=1, limit=100
    )

    want_pairs = [
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ]
    want_tokens, want_labels = self._split_tokens_labels(want_pairs)

    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_start_limit(self):
    """Skip-gram restricted to a sub-span of the input via start/limit."""
    words = tf.constant([b"foo", b"the", b"quick", b"brown", b"bar"])

    # Only the 3 tokens starting at index 1 participate; b"foo" and
    # b"bar" are excluded entirely.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=1, start=1, limit=3
    )

    want_pairs = [
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ]
    want_tokens, want_labels = self._split_tokens_labels(want_pairs)

    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_skips_exceed_length(self):
    """Window sizes larger than the input clamp to the input bounds."""
    words = tf.constant([b"the", b"quick", b"brown"])

    tokens, labels = text.skip_gram_sample(words, min_skips=100, max_skips=100)

    # Every token pairs with every other token.
    want_pairs = [
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
    ]
    want_tokens, want_labels = self._split_tokens_labels(want_pairs)

    self.assertAllEqual(want_tokens, tokens)
    self.assertAllEqual(want_labels, labels)
def test_skip_gram_sample_limit_exceeds():
    """A `limit` larger than the remaining input is clipped to the input end."""
    words = tf.constant([b"foo", b"the", b"quick", b"brown"])

    # start=1 skips b"foo"; limit=100 covers everything that's left.
    tokens, labels = text.skip_gram_sample(
        words, min_skips=1, max_skips=1, start=1, limit=100
    )

    want_pairs = [
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ]
    want_tokens, want_labels = _split_tokens_labels(want_pairs)

    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
def test_skip_gram_sample_skips_2():
    """Skip-gram with a fixed window (min_skips = max_skips = 2) is deterministic."""
    words = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])

    tokens, labels = text.skip_gram_sample(words, min_skips=2, max_skips=2)

    # Each token pairs with up to two neighbors on each side.
    want_pairs = [
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
    ]
    want_tokens, want_labels = _split_tokens_labels(want_pairs)

    np.testing.assert_equal(np.asanyarray(want_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(want_labels), labels.numpy())
def test_skip_gram_sample_errors(self):
    """Tests various errors raised by skip_gram_sample()."""
    input_tensor = tf.constant([b"the", b"quick", b"brown"])

    # min_skips and max_skips must be >= 0, and min_skips <= max_skips.
    for bad_min, bad_max in [(-1, 2), (1, -2), (2, 1)]:
        with self.assertRaises(tf.errors.InvalidArgumentError):
            text.skip_gram_sample(input_tensor, min_skips=bad_min, max_skips=bad_max)

    # An eager input tensor must be rank 1.
    with self.assertRaises(tf.errors.InvalidArgumentError):
        rank2 = tf.constant([[b"the"], [b"quick"], [b"brown"]])
        text.skip_gram_sample(rank2)

    # vocab_min_count, vocab_subsampling, and corpus_size each require a
    # vocab_freq_table to be provided.
    dummy_input = tf.constant([""])
    needs_table = [
        dict(vocab_min_count=1),
        dict(vocab_subsampling=1e-5),
        dict(corpus_size=100),
        dict(vocab_subsampling=1e-5, corpus_size=100),
    ]
    for kwargs in needs_table:
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input, vocab_freq_table=None, **kwargs)

    # vocab_subsampling and corpus_size must both be present or both absent.
    dummy_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer([b"foo"], [10]), -1
    )
    half_specified = [
        dict(vocab_subsampling=None, corpus_size=100),
        dict(vocab_subsampling=1e-5, corpus_size=None),
    ]
    for kwargs in half_specified:
        with self.assertRaises(ValueError):
            text.skip_gram_sample(dummy_input, vocab_freq_table=dummy_table, **kwargs)
def test_skip_gram_sample_errors_v1(self):
    """Tests various errors raised by skip_gram_sample()."""
    # A rank-2 input tensor is rejected with a ValueError.
    with self.assertRaises(ValueError):
        rank2 = tf.constant([[b"the"], [b"quick"], [b"brown"]])
        text.skip_gram_sample(rank2)
def test_skip_gram_sample_errors_v1():
    """Tests various errors raised by skip_gram_sample()."""
    # A rank-2 input tensor is rejected with an InvalidArgumentError.
    with pytest.raises(tf.errors.InvalidArgumentError):
        rank2 = tf.constant([[b"the"], [b"quick"], [b"brown"]])
        text.skip_gram_sample(rank2)