def test_filter_input_filter_vocab():
    """Tests input filtering based on vocab frequency table and
    thresholds."""
    input_tensor = tf.constant(
        [b"the", b"answer", b"to", b"life", b"and", b"universe"])
    keys = tf.constant([b"and", b"life", b"the", b"to", b"universe"])
    values = tf.constant([0, 1, 2, 3, 4], tf.dtypes.int64)
    vocab_freq_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values), -1)

    # No vocab_freq_table specified - output should be the same as input
    no_table_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=None,
        vocab_min_count=None,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None,
    )
    np.testing.assert_equal(input_tensor.numpy(),
                            np.asanyarray(no_table_output))

    # vocab_freq_table specified, but no vocab_min_count - output should
    # have filtered out tokens not in the table (b"answer").
    table_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=None,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None,
    )
    np.testing.assert_equal(
        np.asanyarray([b"the", b"to", b"life", b"and", b"universe"]),
        table_output.numpy(),
    )

    # vocab_freq_table and vocab_min_count specified - output should have
    # filtered out tokens whose frequencies are below the threshold
    # (b"and": 0, b"life": 1).
    threshold_output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=2,
        vocab_subsampling=None,
        corpus_size=None,
        seed=None,
    )
    np.testing.assert_equal(np.asanyarray([b"the", b"to", b"universe"]),
                            threshold_output.numpy())


# NOTE(review): this test was previously nested (with method indentation and a
# `self` parameter) inside test_filter_input_filter_vocab, so pytest never
# collected it and `self.assertAllEqual` would have failed anyway. It is now a
# module-level pytest-style function consistent with the test above.
def test_filter_input_subsample_vocab():
    """Tests input filtering based on vocab subsampling."""
    # The outputs are non-deterministic, so set random seed to help ensure
    # that the outputs remain constant for testing.
    tf.compat.v1.set_random_seed(42)

    input_tensor = tf.constant([
        # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
        b"the",
        b"answer",  # Not in vocab. (Always discarded)
        b"to",  # keep_prob = 0.75.
        b"life",  # keep_prob > 1. (Always kept)
        b"and",  # keep_prob = 0.48.
        b"universe"  # Below vocab threshold of 3. (Always discarded)
    ])
    keys = tf.constant([b"and", b"life", b"the", b"to", b"universe"])
    values = tf.constant([40, 8, 30, 20, 2], tf.dtypes.int64)
    vocab_freq_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values), -1)

    output = skip_gram_ops._filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=3,
        vocab_subsampling=0.05,
        corpus_size=tf.math.reduce_sum(values),
        seed=9)
    # Use the same numpy comparison style as the other tests in this file
    # (replaces the TestCase-only self.assertAllEqual).
    np.testing.assert_equal(np.asanyarray([b"the", b"to", b"life", b"and"]),
                            output.numpy())