def test_skip_gram_sample_random_skips_default_seed(self):
  """Tests outputs are still random when no op-level seed is specified."""
  # This is needed since tests set a graph-level seed by default. We want to
  # explicitly avoid setting both graph-level seed and op-level seed, to
  # simulate behavior under non-test settings when the user doesn't provide a
  # seed to us. This results in random_seed.get_seed() returning None for both
  # seeds, forcing the C++ kernel to execute its default seed logic.
  random_seed.set_random_seed(None)

  # Uses an input tensor with 10 words, with possible skip ranges in
  # [1, 5]. Thus, the probability that two random samplings would result in
  # the same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being
  # flaky).
  input_tensor = constant_op.constant([str(x) for x in range(10)])

  # Do not provide an op-level seed here!
  tokens_1, labels_1 = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=5)
  tokens_2, labels_2 = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=5)

  with self.cached_session() as sess:
    tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
        [tokens_1, labels_1, tokens_2, labels_2])

  if len(tokens_1_eval) == len(tokens_2_eval):
    self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
  if len(labels_1_eval) == len(labels_2_eval):
    self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())

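# Illustrative sketch (not one of the original tests; the function name is
# ours), showing the two-level seeding behavior the test above relies on.
# Assumes TF1 graph-mode semantics of random_seed.get_seed() and uses only
# modules this test file already imports.
def _seeding_behavior_sketch():
  # No graph-level seed and no op-level seed: get_seed() yields
  # (None, None), so the C++ kernel falls back to its default seed logic
  # and two samplings are expected to differ.
  random_seed.set_random_seed(None)
  no_seeds = random_seed.get_seed(op_seed=None)  # (None, None)

  # Both seeds present: get_seed() yields a deterministic pair, which is
  # what makes tests like test_skip_gram_sample_random_skips reproducible.
  random_seed.set_random_seed(42)
  fixed_seeds = random_seed.get_seed(op_seed=9)  # deterministic pair
  return no_seeds, fixed_seeds
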
def test_skip_gram_sample_emit_self(self):
  """Tests skip-gram with emit_self_as_target = True."""
  input_tensor = constant_op.constant(
      [b"the", b"quick", b"brown", b"fox", b"jumps"])
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (b"the", b"the"),
      (b"the", b"quick"),
      (b"the", b"brown"),
      (b"quick", b"the"),
      (b"quick", b"quick"),
      (b"quick", b"brown"),
      (b"quick", b"fox"),
      (b"brown", b"the"),
      (b"brown", b"quick"),
      (b"brown", b"brown"),
      (b"brown", b"fox"),
      (b"brown", b"jumps"),
      (b"fox", b"quick"),
      (b"fox", b"brown"),
      (b"fox", b"fox"),
      (b"fox", b"jumps"),
      (b"jumps", b"brown"),
      (b"jumps", b"fox"),
      (b"jumps", b"jumps"),
  ])
  with self.cached_session():
    self.assertAllEqual(expected_tokens, tokens.eval())
    self.assertAllEqual(expected_labels, labels.eval())

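# The tests in this section call self._split_tokens_labels(), a helper
# defined elsewhere in this test class. A minimal sketch of the assumed
# behavior: it splits a list of (token, label) pairs into two parallel
# lists for comparison against the op's outputs.
def _split_tokens_labels(self, output):
  tokens = [x[0] for x in output]
  labels = [x[1] for x in output]
  return tokens, labels
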
def test_skip_gram_sample_batch(self):
  """Tests skip-gram with batching."""
  input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=1, batch_size=3)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (b"the", b"quick"),
      (b"quick", b"the"),
      (b"quick", b"brown"),
      (b"brown", b"quick"),
      (b"brown", b"fox"),
      (b"fox", b"brown"),
  ])
  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)

    tokens_eval, labels_eval = sess.run([tokens, labels])
    self.assertAllEqual(expected_tokens[:3], tokens_eval)
    self.assertAllEqual(expected_labels[:3], labels_eval)

    tokens_eval, labels_eval = sess.run([tokens, labels])
    self.assertAllEqual(expected_tokens[3:6], tokens_eval)
    self.assertAllEqual(expected_labels[3:6], labels_eval)

    coord.request_stop()
    coord.join(threads)

def test_skip_gram_sample_random_skips(self):
  """Tests skip-gram with min_skips != max_skips, with random output."""
  # The number of outputs is non-deterministic in this case, so set random
  # seed to help ensure the outputs remain constant for this test case.
  random_seed.set_random_seed(42)

  input_tensor = constant_op.constant(
      [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=2, seed=9)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (b"the", b"quick"),
      (b"the", b"brown"),
      (b"quick", b"the"),
      (b"quick", b"brown"),
      (b"quick", b"fox"),
      (b"brown", b"the"),
      (b"brown", b"quick"),
      (b"brown", b"fox"),
      (b"brown", b"jumps"),
      (b"fox", b"brown"),
      (b"fox", b"jumps"),
      (b"jumps", b"fox"),
      (b"jumps", b"over"),
      (b"over", b"fox"),
      (b"over", b"jumps"),
  ])
  with self.cached_session() as sess:
    tokens_eval, labels_eval = sess.run([tokens, labels])
    self.assertAllEqual(expected_tokens, tokens_eval)
    self.assertAllEqual(expected_labels, labels_eval)

def test_skip_gram_sample_skips_0(self):
  """Tests skip-gram with min_skips = max_skips = 0."""
  input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

  # If emit_self_as_target is False (default), output will be empty.
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
  with self.cached_session():
    self.assertEqual(0, tokens.eval().size)
    self.assertEqual(0, labels.eval().size)

  # If emit_self_as_target is True, each token will be its own label.
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (b"the", b"the"),
      (b"quick", b"quick"),
      (b"brown", b"brown"),
  ])
  with self.cached_session():
    self.assertAllEqual(expected_tokens, tokens.eval())
    self.assertAllEqual(expected_labels, labels.eval())

def test_skip_gram_sample_non_string_input(self):
  """Tests skip-gram with non-string input."""
  input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=1)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (1, 2),
      (2, 1),
      (2, 3),
      (3, 2),
  ])
  with self.cached_session():
    self.assertAllEqual(expected_tokens, tokens.eval())
    self.assertAllEqual(expected_labels, labels.eval())

def test_skip_gram_sample_limit_exceeds(self):
  """Tests skip-gram when limit exceeds the length of the input."""
  input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
  tokens, labels = text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
  expected_tokens, expected_labels = self._split_tokens_labels([
      (b"the", b"quick"),
      (b"quick", b"the"),
      (b"quick", b"brown"),
      (b"brown", b"quick"),
  ])
  with self.cached_session():
    self.assertAllEqual(expected_tokens, tokens.eval())
    self.assertAllEqual(expected_labels, labels.eval())

def test_skip_gram_sample_errors(self):
  """Tests various errors raised by skip_gram_sample()."""
  input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

  invalid_skips = (
      # min_skips and max_skips must be >= 0.
      (-1, 2),
      (1, -2),
      # min_skips must be <= max_skips.
      (2, 1))
  for min_skips, max_skips in invalid_skips:
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=min_skips, max_skips=max_skips)
    with self.cached_session() as sess, self.assertRaises(
        errors.InvalidArgumentError):
      sess.run([tokens, labels])

  # input_tensor must be of rank 1.
  with self.assertRaises(ValueError):
    invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
    text.skip_gram_sample(invalid_tensor)

  # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
  # or corpus_size is specified.
  dummy_input = constant_op.constant([""])
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input, vocab_freq_table=None, vocab_min_count=1)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=None,
        vocab_subsampling=1e-5,
        corpus_size=100)

  # vocab_subsampling and corpus_size must both be present or absent.
  dummy_table = lookup.HashTable(
      lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=dummy_table,
        vocab_subsampling=None,
        corpus_size=100)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=dummy_table,
        vocab_subsampling=1e-5,
        corpus_size=None)

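# For contrast with the error cases above, a hedged sketch (not one of the
# original tests; the function name and all values are illustrative) of a
# configuration that passes the argument validation: vocab_freq_table
# supplied together with vocab_min_count, and vocab_subsampling paired
# with corpus_size.
def _valid_vocab_config_sketch():
  freq_table = lookup.HashTable(
      lookup.KeyValueTensorInitializer(
          [b"the", b"quick", b"brown"], [100, 10, 10]), -1)
  return text.skip_gram_sample(
      constant_op.constant([b"the", b"quick", b"brown"]),
      min_skips=1,
      max_skips=1,
      vocab_freq_table=freq_table,
      vocab_min_count=1,
      vocab_subsampling=1e-5,
      corpus_size=120,
      seed=42)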