    # needs: import jax; from transformers import GPT2Tokenizer, FlaxGPT2LMHeadModel
    def test_batch_generation(self):
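        # left-padding aligns prompts of different lengths at the right edge,
        # which decoder-only models need for batched generation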
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2",
                                                  pad_token="</s>",
                                                  padding_side="left")
        inputs = tokenizer(["Hello this is a long string", "Hey"],
                           return_tensors="jax",
                           padding=True,
                           truncation=True)

        model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
        model.config.do_sample = False  # greedy decoding
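        # GPT-2 has no dedicated pad token, so reuse EOS for padding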
        model.config.pad_token_id = model.config.eos_token_id

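        # compile the entire generation loop once with XLA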
        jit_generate = jax.jit(model.generate)

        output_sequences = jit_generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"]).sequences

        output_string = tokenizer.batch_decode(output_sequences,
                                               skip_special_tokens=True)

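        # greedy continuations expected from the pretrained gpt2 checkpoint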
        expected_string = [
            "Hello this is a long string of words. I'm going to try to explain what I mean.",
            "Hey, I'm not sure if I'm going to be able to do",
        ]

        self.assertListEqual(output_string, expected_string)
Example #2
    def check_use_cache_generation(self, config, input_ids):
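        # checks greedy generation with a pre-allocated key/value cache, written
        # so the whole loop can be jit-compiled; assumes import jax,
        # import jax.numpy as jnp, from jax import lax, and FlaxGPT2LMHeadModel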
        prompt_length = 3
        model = FlaxGPT2LMHeadModel(config)
        max_length = 10
        batch_size = 1

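        # single-row prompt of fixed length taken from the provided input ids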
        prompt_ids = input_ids[:1, :prompt_length]

        # put all generation logic into one function
        def generate(prompt_ids):
            def first_pass(prompt_ids, past_key_values):
                # run the full prompt through the model once to prime the cache
                logits, cache = model(prompt_ids,
                                      past_key_values=past_key_values)[:2]
                next_token = jnp.argmax(logits[:, -1:], axis=-1)
                return next_token, cache

            def greedy_search_cond_fn(state):
                cur_len, _, _, _ = state
                # stop one step early; the final token is written after the loop
                return cur_len < max_length - 1

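            # one decoding step: write the current token into the output buffer,
            # then feed only that token through the model with the cache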
            def greedy_search_body_fn(state):
                cur_len, sequences, current_token, cache = state
                next_sequences = lax.dynamic_update_slice(
                    sequences, current_token, (0, cur_len))

                next_logits, next_cache = model(current_token,
                                                past_key_values=cache)[:2]
                next_token = jnp.argmax(next_logits, axis=-1)

                return cur_len + 1, next_sequences, next_token, next_cache

            # allocate the output buffer and copy the prompt into its start
            init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
            init_sequences = lax.dynamic_update_slice(init_sequences,
                                                      prompt_ids, (0, 0))

            # pre-allocate the key/value cache for max_length positions
            past_key_values = model.init_cache(batch_size, max_length)

            # first pass over the full prompt
            next_token, cache = first_pass(prompt_ids, past_key_values)

            # prepare state for generation loop
            init_state = (jnp.array(prompt_length), init_sequences, next_token,
                          cache)

            # decode token by token; lax.while_loop keeps the loop in one
            # compiled XLA program
            _, output_sequences, final_token, _ = lax.while_loop(
                greedy_search_cond_fn, greedy_search_body_fn, init_state)

            # write the final token into the last position
            output_sequences = lax.dynamic_update_slice(
                output_sequences, final_token, (0, max_length - 1))

            return output_sequences

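        # the whole generate function, while loop included, compiles to a
        # single XLA program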
        jit_generate = jax.jit(generate)
        output_sequences = jit_generate(prompt_ids)
        self.parent.assertEqual(output_sequences.shape, (1, max_length))