def test_greedy(self, padded_decode):
    """Check the exact ids generated when no sampling options are set.

    Args:
      padded_decode: Forwarded unchanged to `SamplingModule`.
    """
    module_kwargs = {
        "length_normalization_fn": None,
        "dtype": tf.float32,
        "symbols_to_logits_fn": self._get_test_symbols_to_logits_fn(),
        "vocab_size": 3,
        "max_decode_length": 4,
        "eos_id": 10,
        "padded_decode": padded_decode,
    }
    decoder = sampling_module.SamplingModule(**module_kwargs)

    start_tokens = tf.constant([9, 1])
    decoded_ids, _ = decoder.generate(
        initial_ids=start_tokens, initial_cache=self.cache)

    # Column 0 of the expectation holds the initial ids (9 and 1).
    self.assertAllEqual([[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]], decoded_ids)
def test_topk(self, padded_decode):
    """Check that top-k sampling yields ids of the expected shape.

    Only the output shape is asserted, so no random seed is fixed here.

    Args:
      padded_decode: Forwarded unchanged to `SamplingModule`.
    """
    top_k_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=tf.constant(0.1),
        top_k=tf.constant(3),
        padded_decode=padded_decode,
        # The other sampling tests in this file disable greedy decoding
        # explicitly, and test_greedy relies on the constructor default to get
        # greedy output. Without this flag the top-k settings would presumably
        # be bypassed -- NOTE(review): confirm against SamplingModule.
        enable_greedy=False)
    ids, _ = top_k_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    # Two rows, five ids per row (same shape as the fixed expectations used by
    # the sibling tests).
    self.assertAllEqual([2, 5], ids.shape)
def test_topp(self, padded_decode):
    """Check seeded top-p (p=0.9, temperature 1.0) output against fixed ids.

    Args:
      padded_decode: Forwarded unchanged to `SamplingModule`.
    """
    sampler_kwargs = {
        "length_normalization_fn": length_norm,
        "dtype": tf.float32,
        "symbols_to_logits_fn": self._get_test_symbols_to_logits_fn(),
        "vocab_size": 3,
        "max_decode_length": 4,
        "eos_id": 10,
        "sample_temperature": tf.constant(1.0),
        "top_p": tf.constant(0.9),
        "padded_decode": padded_decode,
        "enable_greedy": False,
    }
    sampler = sampling_module.SamplingModule(**sampler_kwargs)

    # Seed after construction and right before generating, so the sampled ids
    # can be compared with a fixed expectation.
    tf.random.set_seed(1)
    start_tokens = tf.constant([9, 1])
    sampled_ids, _ = sampler.generate(
        initial_ids=start_tokens, initial_cache=self.cache)

    expected_ids = tf.constant([[9, 1, 0, 2, 2], [1, 0, 1, 2, 0]])
    self.assertAllEqual(expected_ids, sampled_ids)
def test_sampling_equivalent_greedy(self, padded_decode):
    """Check sampling configurations that should match greedy decoding.

    Four configurations are decoded from the same initial ids and cache:
    top-p with p=0.0, top-k with k=1, a very low temperature (all compared
    with `greedy_expected`), and a high temperature (compared with a fixed,
    seeded sample).

    Args:
      padded_decode: Forwarded unchanged to every `SamplingModule`.
    """
    # Ensure that p=0.0 with no sample temperature is same as greedy.
    top_p_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=0.0,
        top_p=tf.constant(0.0),
        padded_decode=padded_decode,
        enable_greedy=False)
    ids, _ = top_p_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    self.assertAllEqual(greedy_expected, ids)

    # Ensure that k=1 with no sample temperature is same as greedy.
    top_k_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=0.0,
        top_k=tf.constant(1),
        padded_decode=padded_decode,
        enable_greedy=False)
    ids, _ = top_k_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    # This check was missing: the k=1 result used to be overwritten by the
    # next section without ever being verified.
    self.assertAllEqual(greedy_expected, ids)

    # Ensure that low sample temperature results in Sharp Distribution (greedy).
    # NOTE(review): unlike the other sections, enable_greedy=False is not
    # passed here -- confirm this is intended.
    low_temperature_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=0.0001,
        padded_decode=padded_decode)
    ids, _ = low_temperature_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    self.assertAllEqual(greedy_expected, ids)

    # Ensure that high sample temperature results in Flat Distribution (random).
    high_temperature_obj = sampling_module.SamplingModule(
        length_normalization_fn=length_norm,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        sample_temperature=10.0,
        padded_decode=padded_decode,
        enable_greedy=False)
    # Seed right before generating so the sampled ids are reproducible.
    tf.random.set_seed(1)
    ids, _ = high_temperature_obj.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    expected = tf.constant([[9, 0, 0, 2, 2], [1, 0, 0, 0, 0]])
    self.assertAllEqual(expected, ids)