def test_filter_input_subsample_vocab(self):
  """Tests input filtering based on vocab subsampling."""
  # The outputs are non-deterministic, so set random seed to help ensure that
  # the outputs remain constant for testing.
  random_seed.set_random_seed(42)

  input_tensor = constant_op.constant([
      # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
      b"the",
      b"answer",  # Not in vocab. (Always discarded)
      b"to",  # keep_prob = 0.75.
      b"life",  # keep_prob > 1. (Always kept)
      b"and",  # keep_prob = 0.48.
      b"universe"  # Below vocab threshold of 3. (Always discarded)
  ])
  keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
  values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
  vocab_freq_table = lookup.HashTable(
      lookup.KeyValueTensorInitializer(keys, values), -1)

  with self.cached_session():
    vocab_freq_table.initializer.run()
    output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=3,
        vocab_subsampling=0.05,
        corpus_size=math_ops.reduce_sum(values),
        seed=9)
    self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval())
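
# A minimal sketch (not part of the original test) of the keep-probability
# formula that the comments above compute, following Eq. 5 of
# http://arxiv.org/abs/1310.4546 as quoted in those comments:
#   keep_prob = (sqrt(freq / (subsampling * corpus)) + 1)
#               * (subsampling * corpus / freq)
# The helper name `_keep_prob` is hypothetical; the counts mirror the test
# fixture above (corpus_size = 40 + 8 + 30 + 20 + 2 = 100, subsampling=0.05).
import math


def _keep_prob(freq, corpus_size, subsampling):
  """Probability of keeping a token that occurs `freq` times in the corpus."""
  threshold = subsampling * corpus_size
  return (math.sqrt(freq / threshold) + 1) * (threshold / freq)


# _keep_prob(30, 100, 0.05) ~= 0.57  (b"the")
# _keep_prob(20, 100, 0.05) == 0.75  (b"to")
# _keep_prob(8, 100, 0.05)  ~= 1.42  (b"life": keep_prob > 1, always kept)
# _keep_prob(40, 100, 0.05) ~= 0.48  (b"and")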
def test_skip_gram_sample_errors(self):
  """Tests various errors raised by skip_gram_sample()."""
  input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

  invalid_skips = (
      # min_skips and max_skips must be >= 0.
      (-1, 2),
      (1, -2),
      # min_skips must be <= max_skips.
      (2, 1))
  for min_skips, max_skips in invalid_skips:
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=min_skips, max_skips=max_skips)
    with self.cached_session() as sess, self.assertRaises(
        errors.InvalidArgumentError):
      sess.run([tokens, labels])

  # input_tensor must be of rank 1.
  with self.assertRaises(ValueError):
    invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
    text.skip_gram_sample(invalid_tensor)

  # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
  # or corpus_size is specified.
  dummy_input = constant_op.constant([""])
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input, vocab_freq_table=None, vocab_min_count=1)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=None,
        vocab_subsampling=1e-5,
        corpus_size=100)

  # vocab_subsampling and corpus_size must both be present or absent.
  dummy_table = lookup.HashTable(
      lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=dummy_table,
        vocab_subsampling=None,
        corpus_size=100)
  with self.assertRaises(ValueError):
    text.skip_gram_sample(
        dummy_input,
        vocab_freq_table=dummy_table,
        vocab_subsampling=1e-5,
        corpus_size=None)
def testMapCaptureLookupTable(self):
  default_val = -1
  keys = constant_op.constant(['brain', 'salad', 'surgery'])
  values = constant_op.constant([0, 1, 2], dtypes.int64)
  table = lookup.HashTable(
      lookup.KeyValueTensorInitializer(keys, values), default_val)
  dataset = dataset_ops.Dataset.from_tensor_slices(
      ['brain', 'salad', 'surgery'])
  dataset = dataset.map(table.lookup)
  it = datasets.Iterator(dataset)
  got = [x.numpy() for x in it]
  self.assertAllEqual([0, 1, 2], got)
def test_filter_input_filter_vocab(self):
  """Tests input filtering based on vocab frequency table and thresholds."""
  input_tensor = constant_op.constant(
      [b"the", b"answer", b"to", b"life", b"and", b"universe"])
  keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
  values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
  vocab_freq_table = lookup.HashTable(
      lookup.KeyValueTensorInitializer(keys, values), -1)

  with self.cached_session():
    vocab_freq_table.initializer.run()

    # No vocab_freq_table specified - output should be the same as input.
    no_table_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=None,
        vocab_min_count=None,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None)
    self.assertAllEqual(input_tensor.eval(), no_table_output.eval())

    # vocab_freq_table specified, but no vocab_min_count - output should have
    # filtered out tokens not in the table (b"answer").
    table_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=None,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None)
    self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
                        table_output.eval())

    # vocab_freq_table and vocab_min_count specified - output should have
    # filtered out tokens whose frequencies are below the threshold
    # (b"and": 0, b"life": 1).
    threshold_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=2,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None)
    self.assertAllEqual([b"the", b"to", b"universe"],
                        threshold_output.eval())
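
# A rough sketch (hypothetical helper, not the real `_filter_input`
# implementation) of the filtering semantics the assertions above exercise:
# a token is dropped if the table lookup returns the default value of -1
# (i.e., the token is out of vocabulary), or if its frequency falls below
# `vocab_min_count` when that threshold is given.
def _filter_by_vocab_sketch(tokens, freqs, vocab_min_count=None, default=-1):
  """Keeps tokens present in `freqs` with frequency >= `vocab_min_count`."""
  kept = []
  for token in tokens:
    freq = freqs.get(token, default)
    if freq == default:
      continue  # Not in vocab: always discarded.
    if vocab_min_count is not None and freq < vocab_min_count:
      continue  # Frequency below threshold: discarded.
    kept.append(token)
  return kept


# Mirrors the third assertion above:
# _filter_by_vocab_sketch(
#     [b"the", b"answer", b"to", b"life", b"and", b"universe"],
#     {b"and": 0, b"life": 1, b"the": 2, b"to": 3, b"universe": 4},
#     vocab_min_count=2)
# => [b"the", b"to", b"universe"]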
def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
  """Skip-gram sampling with a text vocabulary file.

  Wrapper around `skip_gram_sample()` for use with a text vocabulary file.
  The vocabulary file is expected to be a plain-text file, with lines of
  `vocab_delimiter`-separated columns. The `vocab_token_index` column should
  contain the vocabulary term, while the `vocab_freq_index` column should
  contain the number of times that term occurs in the corpus. For example,
  with a text vocabulary file of:

  ```
  bonjour,fr,42
  hello,en,777
  hola,es,99
  ```

  You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
  `vocab_freq_index=2`.

  See `skip_gram_sample()` documentation for more details about the skip-gram
  sampling process.

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram
      candidates.
    vocab_freq_file: `string` specifying full file path to the text vocab
      file.
    vocab_token_index: `int` specifying which column in the text vocab file
      contains the tokens.
    vocab_token_dtype: `DType` specifying the format of the tokens in the
      text vocab file.
    vocab_freq_index: `int` specifying which column in the text vocab file
      contains the frequency counts of the tokens.
    vocab_freq_dtype: `DType` specifying the format of the frequency counts
      in the text vocab file.
    vocab_delimiter: `string` specifying the delimiter used in the text vocab
      file.
    vocab_min_count: `int`, `float`, or scalar `Tensor` specifying minimum
      frequency threshold (from `vocab_freq_file`) for a token to be kept in
      `input_tensor`. This should correspond with `vocab_freq_dtype`.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently will be randomly down-sampled. Reasonable starting values
      may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546
      for more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_file` and `vocab_subsampling` must also be specified. If
      `corpus_size` is needed but not supplied, then it will be calculated
      from `vocab_freq_file`. You might want to supply your own value if you
      have already eliminated infrequent tokens from your vocabulary files
      (where frequency < vocab_min_count) to save memory in the internal
      token lookup table. Otherwise, the unused tokens' variables will waste
      memory. The user-supplied `corpus_size` value must be greater than or
      equal to the sum of all the frequency counts of `vocab_freq_file`.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label outputted will
      be the token itself.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
      from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of elements
      in `input_tensor` to use in generating skip-gram candidates. -1 means
      to use the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See
      [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
      for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as
    they are evaluated together.

  Raises:
    ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0
      or exceeds the number of columns in `vocab_freq_file`. If
      `vocab_token_index` and `vocab_freq_index` are both set to the same
      column. If any token in `vocab_freq_file` has a negative frequency.
  """
  if vocab_token_index < 0 or vocab_freq_index < 0:
    raise ValueError(
        "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
        format(vocab_token_index, vocab_freq_index))
  if vocab_token_index == vocab_freq_index:
    raise ValueError(
        "vocab_token_index and vocab_freq_index should be different, but are "
        "both {}.".format(vocab_token_index))

  # Iterates through the vocab file and calculates the number of vocab terms
  # as well as the total corpus size (by summing the frequency counts of all
  # the vocab terms).
  calculated_corpus_size = 0.0
  vocab_size = 0
  with gfile.GFile(vocab_freq_file, mode="r") as f:
    reader = csv.reader(f, delimiter=vocab_delimiter)
    for row in reader:
      if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
        raise ValueError(
            "Row in vocab file only has {} columns, so vocab_token_index={} "
            "or vocab_freq_index={} is out of bounds. Row content: {}".format(
                len(row), vocab_token_index, vocab_freq_index, row))
      vocab_size += 1
      freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
      if freq < 0:
        raise ValueError(
            "Row in vocab file has negative frequency of {}. Row content: {}"
            .format(freq, row))
      # Note: tokens whose frequencies are below vocab_min_count will still
      # contribute to the total corpus size used for vocab subsampling.
      calculated_corpus_size += freq

  if not corpus_size:
    corpus_size = calculated_corpus_size
  elif calculated_corpus_size - corpus_size > 1e-6:
    raise ValueError(
        "`corpus_size`={} must be greater than or equal to the sum of all the "
        "frequency counts ({}) of `vocab_freq_file` ({}).".format(
            corpus_size, calculated_corpus_size, vocab_freq_file))

  vocab_freq_table = lookup.HashTable(
      lookup.TextFileInitializer(
          filename=vocab_freq_file,
          key_dtype=vocab_token_dtype,
          key_index=vocab_token_index,
          value_dtype=vocab_freq_dtype,
          value_index=vocab_freq_index,
          vocab_size=vocab_size,
          delimiter=vocab_delimiter),
      # For vocab terms not in vocab file, use a default value of -1.
      default_value=-1)

  return skip_gram_sample(
      input_tensor,
      min_skips=min_skips,
      max_skips=max_skips,
      start=start,
      limit=limit,
      emit_self_as_target=emit_self_as_target,
      vocab_freq_table=vocab_freq_table,
      vocab_min_count=vocab_min_count,
      vocab_subsampling=vocab_subsampling,
      # corpus_size is not used unless vocab_subsampling is specified.
      corpus_size=None if vocab_subsampling is None else corpus_size,
      batch_size=batch_size,
      batch_capacity=batch_capacity,
      seed=seed,
      name=name)
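
# Example usage (an illustrative sketch; "vocab.csv" and the parameter values
# below are assumptions, not part of the library). Given a vocab file laid
# out like the docstring example above, with frequency counts in column 2:
#
#   bonjour,fr,42
#   hello,en,777
#   hola,es,99
#
#   tokens, labels = skip_gram_sample_with_text_vocab(
#       input_tensor=constant_op.constant([b"hello", b"hola", b"bonjour"]),
#       vocab_freq_file="vocab.csv",
#       vocab_token_index=0,
#       vocab_freq_index=2,
#       vocab_delimiter=",",
#       vocab_min_count=50,  # b"bonjour" (freq 42) is filtered out.
#       min_skips=1,
#       max_skips=2,
#       batch_size=16)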