Beispiel #1
0
    def test_batch_generation(self):
        """Check jitted Flax greedy generation on a left-padded batch.

        Two prompts of different lengths are tokenized into one left-padded
        batch, passed through ``jax.jit(model.generate)`` together with the
        attention mask, and decoded. The decoded strings must equal the
        recorded reference continuations.
        """
        # Tokenizer and model must come from the same checkpoint so their
        # vocabularies agree ("XGLM" is not a valid checkpoint id).
        checkpoint = "facebook/xglm-564M"
        # Left padding keeps every prompt's last real token at the end of
        # the row, which is where a causal LM continues from.
        tokenizer = XGLMTokenizer.from_pretrained(checkpoint,
                                                  padding_side="left")
        inputs = tokenizer(["Hello this is a long string", "Hey"],
                           return_tensors="np",
                           padding=True,
                           truncation=True)

        model = FlaxXGLMForCausalLM.from_pretrained(checkpoint)
        # Force greedy decoding so the output is deterministic.
        model.config.num_beams = 1
        model.config.do_sample = False

        jit_generate = jax.jit(model.generate)

        output_sequences = jit_generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"]).sequences

        output_string = tokenizer.batch_decode(output_sequences,
                                               skip_special_tokens=True)

        expected_string = [
            "Hello this is a long string of questions, but I'm not sure if I'm",
            "Hey, I'm a newbie to the forum and I'",
        ]

        self.assertListEqual(output_string, expected_string)
    def test_xglm_sample_max_time(self):
        """Check that ``generate`` honours ``max_time`` across decoding modes.

        Sampling, greedy search, beam search and beam sampling are each run
        with ``max_time=MAX_TIME`` and a generous ``max_length``. Each run
        must take more than ``MAX_TIME`` but less than ``1.5 * MAX_TIME``
        seconds. A greedy run with ``max_time=None`` must take longer than
        ``1.25 * MAX_TIME``, showing that the time limit is what stopped the
        bounded runs.
        """
        # A monotonic clock is unaffected by system wall-clock adjustments,
        # which could otherwise corrupt the measured durations.
        import time

        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.15

        def timed_generate(**generate_kwargs):
            """Run ``model.generate`` on the prompt; return elapsed seconds."""
            start = time.perf_counter()
            model.generate(input_ids, max_length=256, **generate_kwargs)
            return time.perf_counter() - start

        # Sampling, greedy, beam search, beam sampling (same order as before,
        # so the seeded RNG stream is consumed identically).
        bounded_configs = [
            {"do_sample": True},
            {"do_sample": False},
            {"do_sample": False, "num_beams": 2},
            {"do_sample": True, "num_beams": 2},
        ]
        for config in bounded_configs:
            duration = timed_generate(max_time=MAX_TIME, **config)
            self.assertGreater(duration, MAX_TIME, msg=f"config={config}")
            self.assertLess(duration, 1.5 * MAX_TIME, msg=f"config={config}")

        # Without a time limit, generation runs to max_length and is slower.
        duration = timed_generate(do_sample=False, max_time=None)
        self.assertGreater(duration, 1.25 * MAX_TIME)
Beispiel #3
0
    def test_xglm_sample(self):
        """Seeded sampling from the default device reproduces a reference text.

        The RNG is seeded before generation, so ``do_sample=True`` with a
        single beam must yield the recorded continuation of the prompt.
        """
        checkpoint = "facebook/xglm-564M"
        tokenizer = XGLMTokenizer.from_pretrained(checkpoint)
        model = XGLMForCausalLM.from_pretrained(checkpoint)

        torch.manual_seed(0)
        prompt_ids = tokenizer("Today is a nice day and",
                               return_tensors="pt").input_ids
        generated = model.generate(prompt_ids, do_sample=True, num_beams=1)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        expected = "Today is a nice day and the sun is shining. A nice day with warm rainy"
        self.assertEqual(decoded, expected)
    def test_xglm_sample(self):
        """Seeded sampling on ``torch_device`` reproduces a reference text.

        Model and prompt are moved to ``torch_device``. With the RNG seeded,
        ``do_sample=True`` and one beam must give the recorded continuation.
        """
        checkpoint = "facebook/xglm-564M"
        tokenizer = XGLMTokenizer.from_pretrained(checkpoint)
        model = XGLMForCausalLM.from_pretrained(checkpoint)
        model.to(torch_device)

        torch.manual_seed(0)
        encoded = tokenizer("Today is a nice day and", return_tensors="pt")
        prompt_ids = encoded.input_ids.to(torch_device)
        generated = model.generate(prompt_ids, do_sample=True, num_beams=1)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        expected = "Today is a nice day and I am happy to show you all about a recent project for my"
        self.assertEqual(decoded, expected)
    def test_batch_generation(self):
        """Batched generation with left padding must match one-by-one runs.

        Two prompts of different lengths are generated together in a
        left-padded batch and also individually. Both the batched decode and
        the pair of individual decodes must equal the reference sentences.
        """
        checkpoint = "facebook/xglm-564M"
        model = XGLMForCausalLM.from_pretrained(checkpoint)
        model.to(torch_device)
        tokenizer = XGLMTokenizer.from_pretrained(checkpoint)

        tokenizer.padding_side = "left"

        # Prompts of different lengths so the batch actually needs padding.
        long_prompt = "Hello, my dog is a little"
        short_prompt = "Today, I"
        sentences = [long_prompt, short_prompt]

        batch = tokenizer(sentences, return_tensors="pt", padding=True)
        batch_ids = batch["input_ids"].to(torch_device)
        batch_outputs = model.generate(
            input_ids=batch_ids,
            attention_mask=batch["attention_mask"].to(torch_device),
        )

        long_ids = tokenizer(long_prompt,
                             return_tensors="pt").input_ids.to(torch_device)
        long_output = model.generate(input_ids=long_ids)

        # Count the pad tokens the short prompt received in the batch and
        # shrink its solo max_length by that amount, so the solo run lines up
        # with the padded batch row.
        short_real_len = batch["attention_mask"][-1].long().sum().cpu().item()
        num_paddings = long_ids.shape[-1] - short_real_len
        short_ids = tokenizer(short_prompt,
                              return_tensors="pt").input_ids.to(torch_device)
        short_output = model.generate(
            input_ids=short_ids,
            max_length=model.config.max_length - num_paddings,
        )

        batched_text = tokenizer.batch_decode(batch_outputs,
                                              skip_special_tokens=True)
        solo_texts = [
            tokenizer.decode(long_output[0], skip_special_tokens=True),
            tokenizer.decode(short_output[0], skip_special_tokens=True),
        ]

        expected_output_sentence = [
            "Hello, my dog is a little bit of a shy one, but he is very friendly",
            "Today, I am going to share with you a few of my favorite things",
        ]
        self.assertListEqual(expected_output_sentence, batched_text)
        self.assertListEqual(expected_output_sentence, solo_texts)
Beispiel #6
0
 def big_tokenizer(self):
     """Load and return the tokenizer of the ``facebook/xglm-564M`` checkpoint."""
     checkpoint_name = "facebook/xglm-564M"
     return XGLMTokenizer.from_pretrained(checkpoint_name)