示例#1
0
  def test_filter_input_subsample_vocab(self):
    """Checks that vocab subsampling randomly drops frequent tokens.

    With corpus_size = 100 (the sum of the table's counts) and
    vocab_subsampling = 0.05, each token's keep probability is
      (sqrt(freq / 5) + 1) * (5 / freq),
    so rare tokens are always kept and frequent ones are sometimes dropped.
    """
    # Subsampling draws random numbers; pinning the graph-level seed (together
    # with the `seed` passed to _filter_input below) keeps the output stable.
    random_seed.set_random_seed(42)

    tokens = constant_op.constant([
        b"the",       # freq 30 -> keep_prob ~= 0.57.
        b"answer",    # Missing from the vocab: always discarded.
        b"to",        # freq 20 -> keep_prob = 0.75.
        b"life",      # freq 8 -> keep_prob > 1: always kept.
        b"and",       # freq 40 -> keep_prob ~= 0.48.
        b"universe",  # freq 2 < vocab_min_count of 3: always discarded.
    ])
    vocab_words = constant_op.constant(
        [b"and", b"life", b"the", b"to", b"universe"])
    word_counts = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
    freq_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(vocab_words, word_counts), -1)
    expected = [b"the", b"to", b"life", b"and"]

    with self.cached_session():
      freq_table.initializer.run()
      filtered = skip_gram_ops._filter_input(
          input_tensor=tokens,
          vocab_freq_table=freq_table,
          vocab_min_count=3,
          vocab_subsampling=0.05,
          corpus_size=math_ops.reduce_sum(word_counts),
          seed=9)
      self.assertAllEqual(expected, filtered.eval())
示例#2
0
  def test_skip_gram_sample_errors(self):
    """Tests various errors raised by skip_gram_sample()."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

    # Invalid (min_skips, max_skips) pairs. The test expects these to be
    # rejected when the graph runs, i.e. as InvalidArgumentError from sess.run.
    invalid_skips = (
        # min_skips and max_skips must be >= 0.
        (-1, 2),
        (1, -2),
        # min_skips must be <= max_skips.
        (2, 1))
    for min_skips, max_skips in invalid_skips:
      tokens, labels = text.skip_gram_sample(
          input_tensor, min_skips=min_skips, max_skips=max_skips)
      with self.cached_session() as sess, self.assertRaises(
          errors.InvalidArgumentError):
        sess.run([tokens, labels])

    # input_tensor must be of rank 1. The rank-2 tensor is built outside the
    # assertRaises block so that only skip_gram_sample() can satisfy it.
    invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
    with self.assertRaises(ValueError):
      text.skip_gram_sample(invalid_tensor)

    # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
    # or corpus_size is specified.
    dummy_input = constant_op.constant([""])
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_min_count=1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=None,
          vocab_subsampling=1e-5,
          corpus_size=100)

    # vocab_subsampling and corpus_size must both be present or absent.
    dummy_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=None,
          corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=1e-5,
          corpus_size=None)
示例#3
0
 def testMapCaptureLookupTable(self):
     """Checks that Dataset.map can call a captured HashTable's lookup."""
     words = ['brain', 'salad', 'surgery']
     table = lookup.HashTable(
         lookup.KeyValueTensorInitializer(
             constant_op.constant(words),
             constant_op.constant([0, 1, 2], dtypes.int64)),
         -1)  # Default value for keys missing from the table.
     mapped = dataset_ops.Dataset.from_tensor_slices(words).map(table.lookup)
     got = [element.numpy() for element in datasets.Iterator(mapped)]
     self.assertAllEqual([0, 1, 2], got)
示例#4
0
  def test_filter_input_filter_vocab(self):
    """Checks filtering by vocab table membership and by a minimum count.

    The table maps b"and"->0, b"life"->1, b"the"->2, b"to"->3 and
    b"universe"->4; b"answer" is absent. Subsampling is disabled in every
    case below.
    """
    tokens = constant_op.constant(
        [b"the", b"answer", b"to", b"life", b"and", b"universe"])
    vocab_words = constant_op.constant(
        [b"and", b"life", b"the", b"to", b"universe"])
    counts = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
    freq_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(vocab_words, counts), -1)

    def filter_tokens(table, min_count):
      """Builds a _filter_input op over `tokens` with subsampling turned off."""
      return skip_gram_ops._filter_input(
          input_tensor=tokens,
          vocab_freq_table=table,
          vocab_min_count=min_count,
          vocab_subsampling=None,
          corpus_size=None,
          seed=None)

    with self.cached_session():
      freq_table.initializer.run()

      # No table: the input must come back unchanged.
      passthrough = filter_tokens(None, None)
      self.assertAllEqual(tokens.eval(), passthrough.eval())

      # Table but no threshold: only the out-of-vocab b"answer" is dropped.
      in_vocab = filter_tokens(freq_table, None)
      self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
                          in_vocab.eval())

      # Table plus a threshold of 2: b"and" (0) and b"life" (1) are dropped too.
      above_threshold = filter_tokens(freq_table, 2)
      self.assertAllEqual([b"the", b"to", b"universe"], above_threshold.eval())
示例#5
0
def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
    """Skip-gram sampling with a text vocabulary file.

    Wrapper around `skip_gram_sample()` for use with a text vocabulary file.
    The vocabulary file is expected to be a plain-text file, with lines of
    `vocab_delimiter`-separated columns. The `vocab_token_index` column should
    contain the vocabulary term, while the `vocab_freq_index` column should
    contain the number of times that term occurs in the corpus. For example,
    with a text vocabulary file of:

      ```
      bonjour,fr,42
      hello,en,777
      hola,es,99
      ```

    You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
    `vocab_freq_index=2`.

    The file is first read once with `csv.reader` to count its rows, validate
    the column indices and frequencies, and sum the frequencies. The tokens
    and frequencies are then loaded into a `HashTable` that returns -1 for
    tokens not in the file.

    See `skip_gram_sample()` documentation for more details about the
    skip-gram sampling process.

    Args:
      input_tensor: A rank-1 `Tensor` from which to generate skip-gram
        candidates.
      vocab_freq_file: `string` specifying full file path to the text vocab
        file.
      vocab_token_index: `int` specifying which column in the text vocab file
        contains the tokens.
      vocab_token_dtype: `DType` specifying the format of the tokens in the
        text vocab file.
      vocab_freq_index: `int` specifying which column in the text vocab file
        contains the frequency counts of the tokens.
      vocab_freq_dtype: `DType` specifying the format of the frequency counts
        in the text vocab file.
      vocab_delimiter: `string` specifying the delimiter used in the text vocab
        file.
      vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
        minimum frequency threshold (from `vocab_freq_file`) for a token to be
        kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
      vocab_subsampling: (Optional) `float` specifying frequency proportion
        threshold for tokens from `input_tensor`. Tokens that occur more
        frequently will be randomly down-sampled. Reasonable starting values
        may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546
        for more details.
      corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
        total number of tokens in the corpus (e.g., sum of all the frequency
        counts of `vocab_freq_file`). Used with `vocab_subsampling` for
        down-sampling frequently occurring tokens; it is only forwarded to
        `skip_gram_sample()` when `vocab_subsampling` is not None. If it is
        None or 0, it is calculated from `vocab_freq_file`. You might want to
        supply your own value if you have already eliminated infrequent tokens
        from your vocabulary files (where frequency < vocab_min_count) to save
        memory in the internal token lookup table. Otherwise, the unused
        tokens' variables will waste memory. A user-supplied `corpus_size`
        must be greater than or equal to the sum of all the frequency counts
        of `vocab_freq_file`. This is checked here, except for a symbolic
        (graph-mode) `Tensor`, whose value is not known until run time.
      min_skips: `int` or scalar `Tensor` specifying the minimum window size to
        randomly use for each token. Must be >= 0 and <= `max_skips`. If
        `min_skips` and `max_skips` are both 0, the only label outputted will
        be the token itself.
      max_skips: `int` or scalar `Tensor` specifying the maximum window size to
        randomly use for each token. Must be >= 0.
      start: `int` or scalar `Tensor` specifying the position in `input_tensor`
        from which to start generating skip-gram candidates.
      limit: `int` or scalar `Tensor` specifying the maximum number of elements
        in `input_tensor` to use in generating skip-gram candidates. -1 means
        to use the rest of the `Tensor` after `start`.
      emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
        each token as a label for itself.
      batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
      batch_capacity: (Optional) `int` specifying batch capacity for the queue
        used for batching returned `Tensors`. Only has an effect if
        `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
      seed: (Optional) `int` used to create a random seed for window size and
        subsampling. See
        [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
        for behavior.
      name: (Optional) A `string` name or a name scope for the operations.

    Returns:
      A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
      rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
      length `batch_size`; if `batch_size` is not specified, they will be of
      random length, though they will be in sync with each other as long as
      they are evaluated together.

    Raises:
      ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0
        or exceeds the number of columns in `vocab_freq_file`. If
        `vocab_token_index` and `vocab_freq_index` are both set to the same
        column. If any token in `vocab_freq_file` has a negative frequency. If
        a non-symbolic `corpus_size` is smaller than the sum of the frequency
        counts in `vocab_freq_file`.
    """
    if vocab_token_index < 0 or vocab_freq_index < 0:
        raise ValueError(
            "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
            format(vocab_token_index, vocab_freq_index))
    if vocab_token_index == vocab_freq_index:
        raise ValueError(
            "vocab_token_index and vocab_freq_index should be different, but are "
            "both {}.".format(vocab_token_index))

    # Iterates through the vocab file and calculates the number of vocab terms
    # as well as the total corpus size (by summing the frequency counts of all
    # the vocab terms).
    calculated_corpus_size = 0.0
    vocab_size = 0
    with gfile.GFile(vocab_freq_file, mode="r") as f:
        reader = csv.reader(f, delimiter=vocab_delimiter)
        for row in reader:
            if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
                raise ValueError(
                    "Row in vocab file only has {} columns, so vocab_token_index={} or "
                    "vocab_freq_index={} is out of bounds. Row content: {}".
                    format(len(row), vocab_token_index, vocab_freq_index, row))
            vocab_size += 1
            freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
            if freq < 0:
                raise ValueError(
                    "Row in vocab file has negative frequency of {}. Row content: {}"
                    .format(freq, row))
            # Note: tokens whose frequencies are below vocab_min_count will still
            # contribute to the total corpus size used for vocab subsampling.
            calculated_corpus_size += freq

    try:
        # None or 0 means "use the size calculated from the file".
        use_calculated_size = not corpus_size
    except TypeError:
        # A symbolic (graph-mode) `Tensor` cannot be used as a Python bool, so
        # its value is unknown until run time. Forward it unvalidated instead
        # of crashing; the docstring advertises scalar `Tensor` support.
        pass
    else:
        if use_calculated_size:
            corpus_size = calculated_corpus_size
        elif calculated_corpus_size - corpus_size > 1e-6:
            raise ValueError(
                "`corpus_size`={} must be greater than or equal to the sum of all the "
                "frequency counts ({}) of `vocab_freq_file` ({}).".format(
                    corpus_size, calculated_corpus_size, vocab_freq_file))

    vocab_freq_table = lookup.HashTable(
        lookup.TextFileInitializer(filename=vocab_freq_file,
                                   key_dtype=vocab_token_dtype,
                                   key_index=vocab_token_index,
                                   value_dtype=vocab_freq_dtype,
                                   value_index=vocab_freq_index,
                                   vocab_size=vocab_size,
                                   delimiter=vocab_delimiter),
        # For vocab terms not in vocab file, use a default value of -1.
        default_value=-1)

    return skip_gram_sample(
        input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        # corpus_size is not used unless vocab_subsampling is specified.
        corpus_size=None if vocab_subsampling is None else corpus_size,
        batch_size=batch_size,
        batch_capacity=batch_capacity,
        seed=seed,
        name=name)