예제 #1
0
    def test_validate_stopping_criteria(self):
        """Exercise validate_stopping_criteria with agreeing, conflicting and empty inputs.

        * A list whose MaxLengthCriteria uses the same value as ``max_length``
          is passed through once (its result is not checked here).
        * A ``max_length`` that differs from the list's MaxLengthCriteria must
          raise a ``UserWarning``.
        * An empty list must come back containing exactly one criterion.
        """
        agreed_length = 10
        other_length = 11

        # Agreeing values: only the call path is exercised.
        validate_stopping_criteria(
            StoppingCriteriaList([MaxLengthCriteria(agreed_length)]),
            agreed_length,
        )

        # Disagreeing values: a UserWarning is expected. The list is built
        # inside the context manager, exactly as in the original test.
        with self.assertWarns(UserWarning):
            validate_stopping_criteria(
                StoppingCriteriaList([MaxLengthCriteria(agreed_length)]),
                other_length,
            )

        # Empty input: the validator is expected to add a single criterion.
        validated = validate_stopping_criteria(
            StoppingCriteriaList(), other_length)
        self.assertEqual(len(validated), 1)
예제 #2
0
def full_inference_greedy(
    t5_encoder,
    t5_decoder,
    input_ids,
    tokenizer,
    timing_profile,
    max_length,
    use_cuda=True,
):
    """Time an end-to-end T5 encode + greedy decode, then return its output.

    Args:
        t5_encoder: Callable invoked as ``t5_encoder(input_ids=...)``; its
            result is fed to the decoder as ``encoder_hidden_states``.
        t5_decoder: Object exposing ``greedy_search(input_ids=...,
            encoder_hidden_states=..., stopping_criteria=...)``.
        input_ids: Encoder inputs, forwarded unchanged to ``t5_encoder``.
        tokenizer: Tokenizer whose pad token id seeds the decoder input.
        timing_profile: Provides ``number`` and ``iterations`` for
            ``measure_python_inference_code``.
        max_length: Limit given to ``MaxLengthCriteria`` for decoding.
        use_cuda: If true, the decoder seed tensor is moved to ``"cuda"``.

    Returns:
        tuple: ``(output, timing)``. ``output`` is the result of one extra
        end-to-end run made after timing. ``timing`` is whatever
        ``measure_python_inference_code`` returned (named as a median here;
        confirm against that helper).
    """
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length)])

    # Decoding starts from a (1, 1) int32 tensor holding the pad token id.
    # NOTE(review): batch size is fixed at 1 — presumably input_ids holds a
    # single sequence; confirm with callers.
    start_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    decoder_seed = torch.full((1, 1), start_token_id, dtype=torch.int32)
    if use_cuda:
        decoder_seed = decoder_seed.to("cuda")

    def _run_once():
        # Encoder and decoder both run inside the timed callable.
        hidden = t5_encoder(input_ids=input_ids)
        return t5_decoder.greedy_search(
            input_ids=decoder_seed,
            encoder_hidden_states=hidden,
            stopping_criteria=stopping_criteria,
        )

    timing = measure_python_inference_code(
        _run_once,
        number=timing_profile.number,
        iterations=timing_profile.iterations,
    )

    # One additional run, outside the measurement, supplies the returned output.
    output = _run_once()
    return output, timing
예제 #3
0
    def test_max_new_tokens_criteria(self):
        """MaxNewTokensCriteria(start_length=5, max_new_tokens=5) fires at size 10.

        For ``_get_tensors`` sizes 5 and 9 the criterion must not fire; for
        size 10 it must. Wrapped in a StoppingCriteriaList, the list must
        report ``max_length == 10``.
        """
        criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)

        # (size passed to _get_tensors, whether the criterion should fire)
        cases = ((5, False), (9, False), (10, True))
        for size, should_stop in cases:
            input_ids, scores = self._get_tensors(size)
            if should_stop:
                self.assertTrue(criteria(input_ids, scores))
            else:
                self.assertFalse(criteria(input_ids, scores))

        wrapped = StoppingCriteriaList([criteria])
        self.assertEqual(wrapped.max_length, 10)
예제 #4
0
    def test_list_criteria(self):
        """A StoppingCriteriaList stops once one of its criteria is met.

        Only MaxLengthCriteria(max_length=10) is meant to trigger here. The
        MaxTimeCriteria is included to check that an unexpired time budget
        does not stop generation by itself.
        """
        # MaxTimeCriteria starts its clock at construction. The original 0.1 s
        # budget could be exceeded between construction and the assertFalse
        # checks on a slow or loaded machine, which made the test flaky. A
        # generous budget keeps the time criterion deterministically inactive,
        # which is what the assertions below require.
        max_time_budget = 60.0

        input_ids, scores = self._get_tensors(5)

        criteria = StoppingCriteriaList([
            MaxLengthCriteria(max_length=10),
            MaxTimeCriteria(max_time=max_time_budget),
        ])

        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(criteria(input_ids, scores))

        # Size 10 reaches MaxLengthCriteria's limit, so the list must stop.
        input_ids, scores = self._get_tensors(10)
        self.assertTrue(criteria(input_ids, scores))