def test_validate_stopping_criteria(self):
    """validate_stopping_criteria passes silently on a matching max_length,
    warns on a mismatch, and populates an empty criteria list."""
    # Matching lengths: no warning expected.
    validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

    # A max_length that disagrees with the criterion should warn the user.
    with self.assertWarns(UserWarning):
        validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)

    # An empty list comes back with exactly one criterion injected.
    returned_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
    self.assertEqual(len(returned_criteria), 1)
def full_inference_greedy(
    t5_encoder,
    t5_decoder,
    input_ids,
    tokenizer,
    timing_profile,
    max_length,
    use_cuda=True,
):
    """Run and time a full T5 encoder + greedy-decode pass end to end.

    Args:
        t5_encoder: encoder module; called as ``t5_encoder(input_ids=...)``.
        t5_decoder: decoder exposing a HF-style ``greedy_search`` method.
        input_ids: encoder input token ids.
        tokenizer: used only to look up the pad token id for the decoder start.
        timing_profile: supplies ``number`` and ``iterations`` for the timer.
        max_length: hard stop for greedy decoding via MaxLengthCriteria.
        use_cuda: when True, move the decoder start tokens to the GPU.

    Returns:
        Tuple of (greedy_search output, e2e timing value as reported by
        ``measure_python_inference_code`` — presumably a median latency).
    """
    halt_conditions = StoppingCriteriaList([MaxLengthCriteria(max_length)])

    # Greedy decoding is seeded with a single pad token (T5 convention).
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    decoder_input_ids = torch.full((1, 1), pad_token_id, dtype=torch.int32)
    if use_cuda:
        decoder_input_ids = decoder_input_ids.to("cuda")

    def _e2e():
        # Encoder forward pass followed by greedy search on the decoder.
        hidden_states = t5_encoder(input_ids=input_ids)
        return t5_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden_states,
            stopping_criteria=halt_conditions,
        )

    full_e2e_median_time = measure_python_inference_code(
        _e2e,
        number=timing_profile.number,
        iterations=timing_profile.iterations,
    )

    # One extra run to produce the returned decode alongside the timing.
    return (_e2e(), full_e2e_median_time)
def test_max_new_tokens_criteria(self):
    """MaxNewTokensCriteria fires once start_length + max_new_tokens tokens exist."""
    criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)

    # Below the 10-token threshold the criterion stays inactive.
    for seq_len in (5, 9):
        ids, logits = self._get_tensors(seq_len)
        self.assertFalse(criteria(ids, logits))

    # start_length (5) + max_new_tokens (5) == 10 triggers stopping.
    ids, logits = self._get_tensors(10)
    self.assertTrue(criteria(ids, logits))

    # The wrapping list exposes the combined limit as max_length.
    wrapped = StoppingCriteriaList([criteria])
    self.assertEqual(wrapped.max_length, 10)
def test_list_criteria(self):
    """A StoppingCriteriaList stops as soon as any member criterion is met."""
    ids, logits = self._get_tensors(5)
    # NOTE: the MaxTimeCriteria clock starts at construction, so the list is
    # built immediately before the calls that exercise it.
    combined = StoppingCriteriaList(
        [
            MaxLengthCriteria(max_length=10),
            MaxTimeCriteria(max_time=0.1),
        ]
    )
    self.assertFalse(combined(ids, logits))

    # Still under both limits at length 9.
    ids, logits = self._get_tensors(9)
    self.assertFalse(combined(ids, logits))

    # Reaching max_length flips the list to True.
    ids, logits = self._get_tensors(10)
    self.assertTrue(combined(ids, logits))