Example #1
    def test_filter_input_subsample_vocab(self):
        """Tests input filtering based on vocab subsampling."""
        # The outputs are non-deterministic, so set random seed to help ensure that
        # the outputs remain constant for testing.
        random_seed.set_random_seed(42)

        input_tensor = constant_op.constant([
            # b"the": keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
            b"the",
            b"answer",  # Not in vocab. (Always discarded)
            b"to",  # keep_prob = 0.75.
            b"life",  # keep_prob > 1. (Always kept)
            b"and",  # keep_prob = 0.48.
            b"universe"  # Below vocab threshold of 3. (Always discarded)
        ])
        keys = constant_op.constant(
            [b"and", b"life", b"the", b"to", b"universe"])
        values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
        vocab_freq_table = lookup.HashTable(
            lookup.KeyValueTensorInitializer(keys, values), -1)

        with self.test_session():
            vocab_freq_table.init.run()
            output = skip_gram_ops._filter_input(
                input_tensor=input_tensor,
                vocab_freq_table=vocab_freq_table,
                vocab_min_count=3,
                vocab_subsampling=0.05,
                corpus_size=math_ops.reduce_sum(values),
                seed=9)
            self.assertAllEqual([b"the", b"to", b"life", b"and"],
                                output.eval())
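
The keep_prob values quoted in the comments follow the word2vec-style subsampling formula keep_prob = (sqrt(freq / (subsampling * corpus_size)) + 1) * (subsampling * corpus_size / freq). The plain-Python sketch below is only an illustration of that formula (it is not the TensorFlow op itself); it reproduces the numbers in the comments for the corpus above, whose total size is 40 + 8 + 30 + 20 + 2 = 100.

import math


def keep_prob(freq, corpus_size, subsampling=0.05):
    # Word2vec-style subsampling keep probability, matching the formula
    # quoted in the test comments above.
    threshold = subsampling * corpus_size
    return (math.sqrt(freq / threshold) + 1) * (threshold / freq)


print(round(keep_prob(30, 100), 2))  # b"the"  -> 0.57
print(round(keep_prob(20, 100), 2))  # b"to"   -> 0.75
print(round(keep_prob(8, 100), 2))   # b"life" -> 1.42 (> 1, always kept)
print(round(keep_prob(40, 100), 2))  # b"and"  -> 0.48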
Example #2
    def test_filter_input_filter_vocab(self):
        """Tests input filtering based on vocab frequency table and thresholds."""
        input_tensor = constant_op.constant(
            [b"the", b"answer", b"to", b"life", b"and", b"universe"])
        keys = constant_op.constant(
            [b"and", b"life", b"the", b"to", b"universe"])
        values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
        vocab_freq_table = lookup.HashTable(
            lookup.KeyValueTensorInitializer(keys, values), -1)

        with self.test_session():
            vocab_freq_table.init.run()

            # No vocab_freq_table specified - output should be the same as input.
            no_table_output = skip_gram_ops._filter_input(
                input_tensor=input_tensor,
                vocab_freq_table=None,
                vocab_min_count=None,
                vocab_subsampling=None,
                corpus_size=None,
                seed=None)
            self.assertAllEqual(input_tensor.eval(), no_table_output.eval())

            # vocab_freq_table specified, but no vocab_min_count - output should have
            # filtered out tokens not in the table (b"answer").
            table_output = skip_gram_ops._filter_input(
                input_tensor=input_tensor,
                vocab_freq_table=vocab_freq_table,
                vocab_min_count=None,
                vocab_subsampling=None,
                corpus_size=None,
                seed=None)
            self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
                                table_output.eval())

            # vocab_freq_table and vocab_min_count specified - output should have
            # filtered out tokens whose frequencies are below the threshold
            # (b"and": 0, b"life": 1).
            threshold_output = skip_gram_ops._filter_input(
                input_tensor=input_tensor,
                vocab_freq_table=vocab_freq_table,
                vocab_min_count=2,
                vocab_subsampling=None,
                corpus_size=None,
                seed=None)
            self.assertAllEqual([b"the", b"to", b"universe"],
                                threshold_output.eval())
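
The three assertions above reduce to two rules: once a table is supplied, tokens absent from it (which the HashTable maps to the default value -1) are always dropped, and when vocab_min_count is also set, tokens whose frequency falls below that threshold are dropped as well. The plain-Python stand-in below (a hypothetical helper, not the actual _filter_input implementation) mirrors that behaviour for the same inputs.

def filter_by_vocab(tokens, vocab_freq=None, vocab_min_count=None):
    # Simplified stand-in for the vocab-filtering step (no subsampling).
    if vocab_freq is None:
        return list(tokens)  # No table: output equals input.
    min_count = 0 if vocab_min_count is None else vocab_min_count
    # Missing tokens get frequency -1 (the HashTable default value), so they
    # always fall below the threshold and are discarded.
    return [t for t in tokens if vocab_freq.get(t, -1) >= min_count]


tokens = [b"the", b"answer", b"to", b"life", b"and", b"universe"]
freq = {b"and": 0, b"life": 1, b"the": 2, b"to": 3, b"universe": 4}
print(filter_by_vocab(tokens, freq))     # drops only b"answer"
print(filter_by_vocab(tokens, freq, 2))  # also drops b"and" (0) and b"life" (1)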