def test_skip_gram_sample_with_text_vocab_errors(self):
    """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
    dummy_input = tf.constant([""])
    vocab_freq_file = self._make_text_vocab_freq_file()

    # Each pair is (vocab_token_index, vocab_freq_index) and must be rejected.
    bad_index_pairs = [
        (-1, 0),  # vocab_token_index can't be negative.
        (0, -1),  # vocab_freq_index can't be negative.
        (0, 0),   # The two indices can't be equal...
        (1, 1),   # ...for any value.
        (0, 2),   # vocab_freq_file only has two columns...
        (2, 0),   # ...so index 2 is out of range.
    ]
    for token_idx, freq_idx in bad_index_pairs:
        with self.assertRaises(ValueError):
            text.skip_gram_sample_with_text_vocab(
                input_tensor=dummy_input,
                vocab_freq_file=vocab_freq_file,
                vocab_token_index=token_idx,
                vocab_freq_index=freq_idx)
def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
    """Tests skip-gram sampling with text vocab and freq threshold filtering."""
    # b"answer" is not in vocab file, and b"universe"'s frequency is below
    # threshold of 3, so both should be dropped before skip-gram candidate
    # generation.
    input_tensor = tf.constant([
        b"the",
        b"answer",  # Will be filtered before candidate generation.
        b"to",
        b"life",
        b"and",
        b"universe"  # Will be filtered before candidate generation.
    ])
    vocab_freq_file = self._make_text_vocab_freq_file()

    tokens, labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=input_tensor,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        vocab_min_count=3,
        min_skips=1,
        max_skips=1)

    # With min_skips == max_skips == 1, each surviving token pairs with its
    # immediate neighbors on both sides.
    expected_pairs = [
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
        (b"life", b"and"),
        (b"and", b"life"),
    ]
    expected_tokens, expected_labels = self._split_tokens_labels(expected_pairs)
    self.assertAllEqual(expected_tokens, tokens)
    self.assertAllEqual(expected_labels, labels)
def _skip_gram_sample_with_text_vocab_errors(vocab_freq_file):
    # Input content is irrelevant here; only the index validation should fire.
    placeholder_input = tf.constant([""])

    # (vocab_token_index, vocab_freq_index) combinations that must raise.
    rejected_indices = (
        (-1, 0),  # vocab_token_index can't be negative.
        (0, -1),  # vocab_freq_index can't be negative.
        (0, 0),   # vocab_token_index can't equal vocab_freq_index...
        (1, 1),   # ...for any value.
        (0, 2),   # vocab_freq_file only has two columns...
        (2, 0),   # ...so column 2 does not exist.
    )

    for token_col, freq_col in rejected_indices:
        with pytest.raises(ValueError):
            text.skip_gram_sample_with_text_vocab(
                input_tensor=placeholder_input,
                vocab_freq_file=vocab_freq_file,
                vocab_token_index=token_col,
                vocab_freq_index=freq_col,
            )
def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
                                       vocab_freq_dtype, corpus_size=None):
    """Runs skip-gram sampling from a text vocab file with subsampling enabled.

    Args:
      vocab_freq_file: Path to a vocab file whose rows are token,frequency
        pairs (token in column 0, frequency in column 1).
      vocab_min_count: Minimum frequency a token needs to be kept.
      vocab_freq_dtype: tf dtype used to parse the frequency column of the
        vocab file.
      corpus_size: Optional total corpus size passed through to
        skip_gram_sample_with_text_vocab; defaults to None (derived from the
        vocab file by the library — TODO confirm against the implementation).
    """
    # The outputs are non-deterministic, so set random seed to help ensure
    # that the outputs remain constant for testing.
    tf.compat.v1.set_random_seed(42)

    input_tensor = tf.constant([
        # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
        b"the",
        b"answer",  # Not in vocab. (Always discarded)
        b"to",  # keep_prob = 0.75.
        b"life",  # keep_prob > 1. (Always kept)
        b"and",  # keep_prob = 0.48.
        b"universe",  # Below vocab threshold of 3. (Always discarded)
    ])
    # keep_prob calculated from vocab file with relative frequencies of:
    # and: 40
    # life: 8
    # the: 30
    # to: 20
    # universe: 2
    tokens, labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=input_tensor,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        # Bug fix: previously hard-coded to tf.dtypes.float64, silently
        # ignoring the vocab_freq_dtype argument this helper accepts.
        vocab_freq_dtype=vocab_freq_dtype,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=0.05,
        corpus_size=corpus_size,
        min_skips=1,
        max_skips=1,
        seed=123,
    )

    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
    ])
    self.assertAllEqual(expected_tokens, tokens)
    self.assertAllEqual(expected_labels, labels)
def test_skip_gram_sample_with_text_vocab_filter_vocab():
    """Tests skip-gram sampling with text vocab and freq threshold filtering."""
    # b"answer" is not in vocab file, and b"universe"'s frequency is below
    # threshold of 3, so both get dropped before candidate generation.
    corpus = tf.constant(
        [
            b"the",
            b"answer",  # Will be filtered before candidate generation.
            b"to",
            b"life",
            b"and",
            b"universe",  # Will be filtered before candidate generation.
        ]
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        freq_file = _make_text_vocab_freq_file(tmp_dir)
        tokens, labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=corpus,
            vocab_freq_file=freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_min_count=3,
            min_skips=1,
            max_skips=1,
        )

    # Surviving tokens pair with their immediate neighbors on each side.
    expected_tokens, expected_labels = _split_tokens_labels(
        [
            (b"the", b"to"),
            (b"to", b"the"),
            (b"to", b"life"),
            (b"life", b"to"),
            (b"life", b"and"),
            (b"and", b"life"),
        ]
    )
    np.testing.assert_equal(np.asanyarray(expected_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(expected_labels), labels.numpy())