Example #1
0
    def _text_vocab_subsample_vocab_helper(self,
                                           vocab_freq_file,
                                           vocab_min_count,
                                           vocab_freq_dtype,
                                           corpus_size=None):
        """Runs skip-gram sampling with a text vocab file and verifies output.

        Args:
          vocab_freq_file: Path to the token/frequency vocab file to load.
          vocab_min_count: Minimum frequency a token needs to be kept.
          vocab_freq_dtype: Dtype used to parse the frequency column.
          corpus_size: Optional explicit corpus size passed through to the op.
        """
        # Subsampling is stochastic; fix the graph-level seed so the sampled
        # pairs stay stable across test runs.
        random_seed.set_random_seed(42)

        words = constant_op.constant([
            # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
            b"the",
            b"answer",  # Not in vocab. (Always discarded)
            b"to",  # keep_prob = 0.75.
            b"life",  # keep_prob > 1. (Always kept)
            b"and",  # keep_prob = 0.48.
            b"universe"  # Below vocab threshold of 3. (Always discarded)
        ])
        # keep_prob calculated from vocab file with relative frequencies of:
        # and: 40
        # life: 8
        # the: 30
        # to: 20
        # universe: 2

        sampled_tokens, sampled_labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=words,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_freq_dtype=vocab_freq_dtype,
            vocab_min_count=vocab_min_count,
            vocab_subsampling=0.05,
            corpus_size=corpus_size,
            min_skips=1,
            max_skips=1,
            seed=123)

        want_tokens, want_labels = self._split_tokens_labels([
            (b"the", b"to"),
            (b"to", b"the"),
            (b"to", b"life"),
            (b"life", b"to"),
        ])
        with self.test_session() as sess:
            lookup_ops.tables_initializer().run()
            got_tokens, got_labels = sess.run([sampled_tokens, sampled_labels])
            self.assertAllEqual(want_tokens, got_tokens)
            self.assertAllEqual(want_labels, got_labels)
Example #2
0
    def test_skip_gram_sample_with_text_vocab_errors(self):
        """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
        dummy_input = constant_op.constant([""])
        vocab_freq_file = self._make_text_vocab_freq_file()

        # (vocab_token_index, vocab_freq_index) pairs that must be rejected.
        bad_index_pairs = (
            (-1, 0),  # vocab_token_index can't be negative.
            (0, -1),  # vocab_freq_index can't be negative.
            (0, 0),  # The two indices can't be equal.
            (1, 1),
            (0, 2),  # vocab_freq_file only has two columns.
            (2, 0),
        )

        for token_idx, freq_idx in bad_index_pairs:
            with self.assertRaises(ValueError):
                text.skip_gram_sample_with_text_vocab(
                    input_tensor=dummy_input,
                    vocab_freq_file=vocab_freq_file,
                    vocab_token_index=token_idx,
                    vocab_freq_index=freq_idx)
  def test_skip_gram_sample_with_text_vocab_errors(self):
    """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
    dummy_input = constant_op.constant([""])
    vocab_freq_file = self._make_text_vocab_freq_file()

    # Each (token_index, freq_index) pair below is invalid: negative indices,
    # identical indices, or an index past the file's two columns.
    rejected_index_pairs = (
        (-1, 0),
        (0, -1),
        (0, 0),
        (1, 1),
        (0, 2),
        (2, 0),
    )

    for token_index, freq_index in rejected_index_pairs:
      with self.assertRaises(ValueError):
        text.skip_gram_sample_with_text_vocab(
            input_tensor=dummy_input,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=token_index,
            vocab_freq_index=freq_index)
  def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
                                         vocab_freq_dtype, corpus_size=None):
    """Runs skip-gram sampling against a text vocab file and checks the pairs.

    Args:
      vocab_freq_file: Path to the token/frequency vocab file to load.
      vocab_min_count: Minimum frequency a token needs to be kept.
      vocab_freq_dtype: Dtype used to parse the frequency column.
      corpus_size: Optional explicit corpus size passed through to the op.
    """
    # Subsampling is stochastic; pin the graph-level seed so the sampled
    # pairs remain stable across test runs.
    random_seed.set_random_seed(42)

    words = constant_op.constant([
        # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
        b"the",
        b"answer",  # Not in vocab. (Always discarded)
        b"to",  # keep_prob = 0.75.
        b"life",  # keep_prob > 1. (Always kept)
        b"and",  # keep_prob = 0.48.
        b"universe"  # Below vocab threshold of 3. (Always discarded)
    ])
    # keep_prob calculated from vocab file with relative frequencies of:
    # and: 40
    # life: 8
    # the: 30
    # to: 20
    # universe: 2

    sampled_tokens, sampled_labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=words,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        vocab_freq_dtype=vocab_freq_dtype,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=0.05,
        corpus_size=corpus_size,
        min_skips=1,
        max_skips=1,
        seed=123)

    want_tokens, want_labels = self._split_tokens_labels([
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
    ])
    with self.test_session() as sess:
      lookup_ops.tables_initializer().run()
      got_tokens, got_labels = sess.run([sampled_tokens, sampled_labels])
      self.assertAllEqual(want_tokens, got_tokens)
      self.assertAllEqual(want_labels, got_labels)
Example #5
0
    def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
        """Tests skip-gram sampling with text vocab and freq threshold filtering."""
        words = constant_op.constant([
            b"the",
            b"answer",  # Will be filtered before candidate generation.
            b"to",
            b"life",
            b"and",
            b"universe"  # Will be filtered before candidate generation.
        ])

        # b"answer" is not in vocab file, and b"universe"'s frequency is below
        # threshold of 3.
        vocab_freq_file = self._make_text_vocab_freq_file()

        sampled_tokens, sampled_labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=words,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_min_count=3,
            min_skips=1,
            max_skips=1)

        # Expected pairs come from the filtered sequence [the, to, life, and]
        # with one skip on each side.
        want_tokens, want_labels = self._split_tokens_labels([
            (b"the", b"to"),
            (b"to", b"the"),
            (b"to", b"life"),
            (b"life", b"to"),
            (b"life", b"and"),
            (b"and", b"life"),
        ])
        with self.test_session():
            lookup_ops.tables_initializer().run()
            self.assertAllEqual(want_tokens, sampled_tokens.eval())
            self.assertAllEqual(want_labels, sampled_labels.eval())
  def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
    """Tests skip-gram sampling with text vocab and freq threshold filtering."""
    words = constant_op.constant([
        b"the",
        b"answer",  # Will be filtered before candidate generation.
        b"to",
        b"life",
        b"and",
        b"universe"  # Will be filtered before candidate generation.
    ])

    # b"answer" is not in vocab file, and b"universe"'s frequency is below
    # threshold of 3.
    vocab_freq_file = self._make_text_vocab_freq_file()

    result_tokens, result_labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=words,
        vocab_freq_file=vocab_freq_file,
        vocab_token_index=0,
        vocab_freq_index=1,
        vocab_min_count=3,
        min_skips=1,
        max_skips=1)

    # Skip-gram pairs expected from the filtered sequence [the, to, life, and].
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
        (b"life", b"and"),
        (b"and", b"life"),
    ])
    with self.test_session():
      lookup_ops.tables_initializer().run()
      self.assertAllEqual(expected_tokens, result_tokens.eval())
      self.assertAllEqual(expected_labels, result_labels.eval())