示例#1
0
    def test_skip_gram_sample_with_text_vocab_errors(self):
        """Verifies skip_gram_sample_with_text_vocab() rejects invalid
        column-index arguments with ValueError."""
        dummy_input = tf.constant([""])
        vocab_freq_file = self._make_text_vocab_freq_file()

        # Each pair is a (vocab_token_index, vocab_freq_index) combination
        # that must be rejected: negative indices, identical indices, or an
        # index past the two columns present in the vocab file.
        bad_index_pairs = [
            (-1, 0),  # token index must be non-negative.
            (0, -1),  # freq index must be non-negative.
            (0, 0),   # token and freq indices must differ.
            (1, 1),
            (0, 2),   # vocab file only has two columns.
            (2, 0),
        ]

        for token_col, freq_col in bad_index_pairs:
            with self.assertRaises(ValueError):
                text.skip_gram_sample_with_text_vocab(
                    input_tensor=dummy_input,
                    vocab_freq_file=vocab_freq_file,
                    vocab_token_index=token_col,
                    vocab_freq_index=freq_col)
示例#2
0
    def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
        """Checks that tokens absent from the vocab file or below the
        frequency threshold are dropped before skip-gram generation."""
        # b"answer" does not appear in the vocab file, and b"universe"'s
        # frequency is below the threshold of 3 — both should be filtered
        # out before candidate generation.
        input_tensor = tf.constant(
            [b"the", b"answer", b"to", b"life", b"and", b"universe"])

        vocab_freq_file = self._make_text_vocab_freq_file()

        tokens, labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=input_tensor,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_min_count=3,
            min_skips=1,
            max_skips=1)

        # With min_skips == max_skips == 1, every surviving token pairs with
        # its immediate neighbors in the filtered sequence
        # [the, to, life, and].
        expected_pairs = [
            (b"the", b"to"),
            (b"to", b"the"),
            (b"to", b"life"),
            (b"life", b"to"),
            (b"life", b"and"),
            (b"and", b"life"),
        ]
        expected_tokens, expected_labels = self._split_tokens_labels(
            expected_pairs)
        self.assertAllEqual(expected_tokens, tokens)
        self.assertAllEqual(expected_labels, labels)
示例#3
0
def _skip_gram_sample_with_text_vocab_errors(vocab_freq_file):
    """Asserts that invalid (token, freq) column indices raise ValueError.

    Covers negative indices, equal indices, and indices past the two
    columns present in `vocab_freq_file`.
    """
    placeholder = tf.constant([""])

    for token_col, freq_col in (
        (-1, 0),  # vocab_token_index can't be negative.
        (0, -1),  # vocab_freq_index can't be negative.
        (0, 0),   # the two indices can't be equal.
        (1, 1),
        (0, 2),   # vocab_freq_file only has two columns.
        (2, 0),
    ):
        with pytest.raises(ValueError):
            text.skip_gram_sample_with_text_vocab(
                input_tensor=placeholder,
                vocab_freq_file=vocab_freq_file,
                vocab_token_index=token_col,
                vocab_freq_index=freq_col,
            )
示例#4
0
    def _text_vocab_subsample_vocab_helper(self,
                                           vocab_freq_file,
                                           vocab_min_count,
                                           vocab_freq_dtype,
                                           corpus_size=None):
        """Runs skip-gram sampling with frequency subsampling and checks the
        expected (token, label) pairs under fixed random seeds."""
        # Sampling is stochastic; pin both the graph-level seed here and the
        # op-level seed below so the expected pairs stay constant.
        tf.compat.v1.set_random_seed(42)

        # keep_prob values are derived from the vocab file's relative
        # frequencies: and: 40, life: 8, the: 30, to: 20, universe: 2.
        input_tensor = tf.constant([
            # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
            b"the",
            b"answer",  # Not in vocab. (Always discarded)
            b"to",  # keep_prob = 0.75.
            b"life",  # keep_prob > 1. (Always kept)
            b"and",  # keep_prob = 0.48.
            b"universe",  # Below vocab threshold of 3. (Always discarded)
        ])

        tokens, labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=input_tensor,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_freq_dtype=tf.dtypes.float64,
            vocab_min_count=vocab_min_count,
            vocab_subsampling=0.05,
            corpus_size=corpus_size,
            min_skips=1,
            max_skips=1,
            seed=123,
        )

        expected_pairs = [
            (b"the", b"to"),
            (b"to", b"the"),
            (b"to", b"life"),
            (b"life", b"to"),
        ]
        expected_tokens, expected_labels = self._split_tokens_labels(
            expected_pairs)
        self.assertAllEqual(expected_tokens, tokens)
        self.assertAllEqual(expected_labels, labels)
示例#5
0
def test_skip_gram_sample_with_text_vocab_filter_vocab():
    """Checks that tokens absent from the vocab file or below the frequency
    threshold are dropped before skip-gram candidate generation."""
    # b"answer" is not in the vocab file; b"universe"'s frequency is below
    # the threshold of 3 — both are filtered before candidate generation.
    input_tensor = tf.constant(
        [b"the", b"answer", b"to", b"life", b"and", b"universe"])

    with tempfile.TemporaryDirectory() as tmp_dir:
        vocab_freq_file = _make_text_vocab_freq_file(tmp_dir)

        tokens, labels = text.skip_gram_sample_with_text_vocab(
            input_tensor=input_tensor,
            vocab_freq_file=vocab_freq_file,
            vocab_token_index=0,
            vocab_freq_index=1,
            vocab_min_count=3,
            min_skips=1,
            max_skips=1,
        )

    # With min_skips == max_skips == 1, each surviving token pairs with its
    # immediate neighbors in the filtered sequence [the, to, life, and].
    expected_pairs = [
        (b"the", b"to"),
        (b"to", b"the"),
        (b"to", b"life"),
        (b"life", b"to"),
        (b"life", b"and"),
        (b"and", b"life"),
    ]
    expected_tokens, expected_labels = _split_tokens_labels(expected_pairs)
    np.testing.assert_equal(np.asanyarray(expected_tokens), tokens.numpy())
    np.testing.assert_equal(np.asanyarray(expected_labels), labels.numpy())