  def test_noise_span_to_unique_sentinel(self):
    vocabulary = test_utils.mock_vocabulary({'foo': 10}, vocab_size=1000)
    tokens = tf.constant([10, 11, 12, 13, 14, 15])
    noise_mask = tf.constant([True, True, False, False, True, False])
    expected_output = [999, 12, 13, 998, 15]
    output = self.evaluate(
        prep.noise_span_to_unique_sentinel(tokens, noise_mask, vocabulary))
    self.assertAllEqual(output, expected_output)
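  # A minimal pure-Python sketch (an assumption for illustration, not the
  # library implementation) of the behavior asserted above: each contiguous
  # run of noise tokens collapses to a single sentinel id, with sentinel ids
  # counting down from the top of the vocabulary (999, 998, ... here).
  def _reference_noise_span_to_unique_sentinel(self, tokens, noise_mask,
                                               vocab_size):
    result = []
    next_sentinel = vocab_size - 1
    previous_was_noise = False
    for token, is_noise in zip(tokens, noise_mask):
      if is_noise:
        if not previous_was_noise:
          # Start of a new noise span: emit a fresh sentinel.
          result.append(next_sentinel)
          next_sentinel -= 1
      else:
        result.append(token)
      previous_was_noise = is_noise
    return result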
  def test_drop_nonnoise_tokens(self):
    vocabulary = test_utils.mock_vocabulary({'foo': 10}, vocab_size=1000)
    tokens = tf.constant([10, 11, 12, 13, 14, 15])
    noise_mask = tf.constant([True, True, False, False, True, False])
    expected_output = [10, 11, 14]
    output = self.evaluate(
        prep.drop_nonnoise_tokens(tokens, noise_mask, vocabulary))
    self.assertAllEqual(output, expected_output)
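  # A minimal sketch (assumed semantics, not the library code) of the
  # selection asserted above: only the tokens at noise positions are kept,
  # i.e. a plain boolean-mask gather over the input sequence.
  def _reference_drop_nonnoise_tokens(self, tokens, noise_mask):
    return [token for token, is_noise in zip(tokens, noise_mask) if is_noise]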
  def test_noise_token_to_gathered_token(self):
    tf.random.set_seed(55)
    vocabulary = test_utils.mock_vocabulary({'foo': 10}, vocab_size=1000)
    tokens = tf.constant([10, 11, 12, 13, 14, 15])
    noise_mask = tf.constant([True, True, False, False, True, False])
    expected_output = [11, 11, 12, 13, 15, 15]
    output = self.evaluate(
        prep.noise_token_to_gathered_token(tokens, noise_mask, vocabulary))
    self.assertAllEqual(output, expected_output)
  def test_noise_token_to_random_token_or_sentinel(self):
    tf.set_random_seed(55)
    vocabulary = test_utils.mock_vocabulary({'foo': 10}, vocab_size=1000)
    tokens = tf.constant(list(range(10)))
    noise_mask = tf.constant(
        [True, True, False, False, True, False, True, True, True, True])
    expected_output = [436, 999, 2, 3, 999, 5, 999, 999, 999, 999]
    output = self.evaluate(prep.noise_token_to_random_token_or_sentinel(
        tokens, noise_mask, vocabulary, random_prob=0.2))
    self.assertAllEqual(output, expected_output)
  def test_permute_noise_tokens(self):
    tf.set_random_seed(55)
    vocabulary = test_utils.mock_vocabulary({'foo': 10}, vocab_size=1000)
    tokens = tf.constant([10, 11, 12, 13, 14, 15])
    noise_mask = tf.constant([True, True, False, False, True, False])
    if six.PY2:
      expected_output = [11, 10, 12, 13, 14, 15]
    else:
      expected_output = [11, 14, 12, 13, 10, 15]
    output = self.evaluate(
        prep.permute_noise_tokens(tokens, noise_mask, vocabulary))
    self.assertAllEqual(output, expected_output)
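  # A rough sketch (an assumption, not the library implementation) of what
  # the assertion above exercises: tokens at noise positions are shuffled
  # among themselves while non-noise tokens stay in place; the concrete
  # expected_output depends on TensorFlow's seeded shuffle. For example,
  # permutation=[1, 2, 0] reproduces the PY3 expected_output above.
  def _reference_permute_noise_tokens(self, tokens, noise_mask, permutation):
    noise_values = [t for t, m in zip(tokens, noise_mask) if m]
    shuffled = iter([noise_values[i] for i in permutation])
    return [next(shuffled) if is_noise else token
            for token, is_noise in zip(tokens, noise_mask)]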