def test_greedy(self, padded_decode):
    """Decode with no sampling options and no length normalization.

    The start ids [9, 1] each begin one output row; the decoded ids must
    match the fixed expectation exactly.
    """
    module_config = dict(
        length_normalization_fn=None,
        dtype=tf.float32,
        symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
        vocab_size=3,
        max_decode_length=4,
        eos_id=10,
        padded_decode=padded_decode,
    )
    decoder = sampling_module.SamplingModule(**module_config)
    decoded_ids, _ = decoder.generate(
        initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
    expected_ids = [[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]]
    self.assertAllEqual(expected_ids, decoded_ids)
 def test_topk(self, padded_decode):
     """Top-k sampling (k=3, temperature 0.1) returns ids of shape [2, 5].

     No random seed is set in this test, so only the output shape is
     asserted, not the sampled values.
     """
     logits_fn = self._get_test_symbols_to_logits_fn()
     top_k_sampler = sampling_module.SamplingModule(
         length_normalization_fn=length_norm,
         dtype=tf.float32,
         symbols_to_logits_fn=logits_fn,
         vocab_size=3,
         max_decode_length=4,
         eos_id=10,
         sample_temperature=tf.constant(0.1),
         top_k=tf.constant(3),
         padded_decode=padded_decode)
     start_ids = tf.constant([9, 1])
     sampled, _ = top_k_sampler.generate(initial_ids=start_ids,
                                         initial_cache=self.cache)
     expected_shape = [2, 5]
     self.assertAllEqual(expected_shape, sampled.shape)
Example #3
0
 def test_topp(self, padded_decode):
     """Seeded top-p sampling (p=0.9, temperature 1.0) yields fixed ids.

     The module is built with enable_greedy=False, and the global TF seed is
     set to 1 right before generate() so the sampled ids are reproducible.
     """
     top_p_settings = {
         'length_normalization_fn': length_norm,
         'dtype': tf.float32,
         'symbols_to_logits_fn': self._get_test_symbols_to_logits_fn(),
         'vocab_size': 3,
         'max_decode_length': 4,
         'eos_id': 10,
         'sample_temperature': tf.constant(1.0),
         'top_p': tf.constant(0.9),
         'padded_decode': padded_decode,
         'enable_greedy': False,
     }
     nucleus_sampler = sampling_module.SamplingModule(**top_p_settings)
     # Seed immediately before decoding so the expected ids below are stable.
     tf.random.set_seed(1)
     sampled_ids, _ = nucleus_sampler.generate(
         initial_ids=tf.constant([9, 1]), initial_cache=self.cache)
     self.assertAllEqual(
         tf.constant([[9, 1, 0, 2, 2], [1, 0, 1, 2, 0]]), sampled_ids)
Example #4
0
    def test_sampling_equivalent_greedy(self, padded_decode):
        """Check sampling settings that should reduce to greedy decoding.

        Cases:
          * top-p with p=0.0 at sample_temperature=0.0 reproduces the greedy ids.
          * top-k with k=1 at sample_temperature=0.0 reproduces the greedy ids.
          * A very low temperature (0.0001) reproduces the greedy ids.
          * A high temperature (10.0) with global seed 1 reproduces a fixed
            sampled (non-greedy) sequence.
        """
        # Greedy output for start ids [9, 1]; the same ids test_greedy expects.
        greedy_expected = [[9, 1, 2, 2, 2], [1, 1, 1, 2, 2]]

        def _decode(*, seed=None, **sampling_kwargs):
            # Builds a SamplingModule with the configuration shared by every
            # case below, plus the case-specific `sampling_kwargs`, and decodes
            # from start ids [9, 1]. When `seed` is given, the global TF seed
            # is set after construction and immediately before generate().
            sampler = sampling_module.SamplingModule(
                length_normalization_fn=length_norm,
                dtype=tf.float32,
                symbols_to_logits_fn=self._get_test_symbols_to_logits_fn(),
                vocab_size=3,
                max_decode_length=4,
                eos_id=10,
                padded_decode=padded_decode,
                **sampling_kwargs)
            if seed is not None:
                tf.random.set_seed(seed)
            decoded, _ = sampler.generate(initial_ids=tf.constant([9, 1]),
                                          initial_cache=self.cache)
            return decoded

        # p=0.0 with no sample temperature is the same as greedy.
        ids = _decode(sample_temperature=0.0,
                      top_p=tf.constant(0.0),
                      enable_greedy=False)
        self.assertAllEqual(greedy_expected, ids)

        # k=1 with no sample temperature is the same as greedy.
        ids = _decode(sample_temperature=0.0,
                      top_k=tf.constant(1),
                      enable_greedy=False)
        self.assertAllEqual(greedy_expected, ids)

        # A low sample temperature gives a sharp distribution (greedy).
        # NOTE(review): this is the only case here that leaves enable_greedy at
        # its default; if that default selects greedy decoding, this check does
        # not actually exercise temperature sampling — confirm.
        ids = _decode(sample_temperature=0.0001)
        self.assertAllEqual(greedy_expected, ids)

        # A high sample temperature gives a flat distribution (random); the
        # fixed seed makes the sampled ids reproducible.
        ids = _decode(seed=1, sample_temperature=10.0, enable_greedy=False)
        expected = tf.constant([[9, 0, 0, 2, 2], [1, 0, 0, 0, 0]])
        self.assertAllEqual(expected, ids)