def test_inference_overriding_task(self):
    """Passing an explicit ``task`` to InferenceApi still yields a list response."""
    model_id = "sentence-transformers/paraphrase-albert-small-v2"
    api = InferenceApi(model_id, task="feature-extraction")
    output = api("This is an example again")
    self.assertIsInstance(output, list)
def test_inference_with_params(self):
    """Forward an extra ``params`` dict and check the response keys.

    The call passes ``candidate_labels`` alongside the text input. The
    response is expected to be a dict with ``sequence`` and ``scores`` keys.
    """
    api = InferenceApi("typeform/distilbert-base-uncased-mnli")
    inputs = "I bought a device but it is not working and I would like to get reimbursed!"
    params = {"candidate_labels": ["refund", "legal", "faq"]}
    result = api(inputs, params)
    self.assertIsInstance(result, dict)
    # assertIn reports the missing key and the actual dict on failure,
    # unlike assertTrue(... in ...), which only reports "False is not true".
    self.assertIn("sequence", result)
    self.assertIn("scores", result)
def test_inference_with_audio(self):
    """Send raw file data from a dummy audio dataset; expect a dict with ``text``."""
    api = InferenceApi("facebook/wav2vec2-large-960h-lv60-self")
    dataset = datasets.load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )
    # NOTE(review): self.read is defined elsewhere in the test class;
    # presumably it returns the raw bytes of the given file path — confirm.
    data = self.read(dataset["file"][0])
    result = api(data=data)
    self.assertIsInstance(result, dict)
    self.assertIn("text", result)
def test_inference_with_dict_inputs(self):
    """A dict input (``question`` + ``context``) yields a dict with ``score`` and ``answer``."""
    api = InferenceApi("deepset/roberta-base-squad2")
    inputs = {
        "question": "What's my name?",
        "context": "My name is Clara and I live in Berkeley.",
    }
    result = api(inputs)
    self.assertIsInstance(result, dict)
    self.assertIn("score", result)
    self.assertIn("answer", result)
def test_simple_inference(self):
    """A plain string containing ``[MASK]`` returns a non-empty list of dicts.

    The first entry must carry ``sequence`` and ``score`` keys.
    """
    api = InferenceApi("bert-base-uncased")
    inputs = "Hi, I think [MASK] is cool"
    results = api(inputs)
    self.assertIsInstance(results, list)
    # Guard the indexing below so an empty response is reported as a test
    # failure rather than an IndexError.
    self.assertGreater(len(results), 0)
    result = results[0]
    self.assertIsInstance(result, dict)
    self.assertIn("sequence", result)
    self.assertIn("score", result)
def test_inference_with_image(self):
    """Send raw file data from a dummy image dataset.

    Expect a non-empty list of dicts, each with ``score`` and ``label`` keys.
    """
    api = InferenceApi("google/vit-base-patch16-224")
    dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")
    # NOTE(review): self.read is defined elsewhere in the test class;
    # presumably it returns the raw bytes of the given file path — confirm.
    data = self.read(dataset["file"][0])
    result = api(data=data)
    self.assertIsInstance(result, list)
    # Without this check an empty list would make the loop below a no-op
    # and the test would pass without verifying anything.
    self.assertGreater(len(result), 0)
    for classification in result:
        self.assertIsInstance(classification, dict)
        self.assertIn("score", classification)
        self.assertIn("label", classification)
def test_inference_missing_input(self):
    """Omitting ``context`` from a question input yields a dict with an ``error`` key."""
    api = InferenceApi("deepset/roberta-base-squad2")
    result = api({"question": "What's my name?"})
    self.assertIsInstance(result, dict)
    self.assertIn("error", result)
def test_inference_overriding_invalid_task(self):
    """An unknown ``task`` must be rejected with a descriptive ValueError."""
    # ``msg=`` on assertRaises is only the failure message shown when no
    # exception is raised; it never inspects the exception's text. Capture
    # the exception and check its message explicitly instead.
    with self.assertRaises(ValueError) as ctx:
        InferenceApi("bert-base-uncased", task="invalid-task")
    self.assertIn(
        "Invalid task invalid-task. Make sure it's valid.", str(ctx.exception)
    )