def Run(self, args):
  """Entry point invoked when the user runs this command.

  Args:
    args: an argparse namespace holding every argument supplied to this
      command invocation.

  Returns:
    The prediction response, which is printed afterwards.
  """
  parsed_instances = predict_utilities.ReadInstancesFromArgs(
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)
  target_ref = predict_utilities.ParseModelOrVersionRef(
      args.model, args.version)
  response = predict.Predict(target_ref, parsed_instances)

  format_given = args.IsSpecified('format')
  if not format_given:
    # No explicit --format: derive one from the returned predictions.
    args.format = predict_utilities.GetDefaultFormat(
        response.get('predictions'))
  return response
def _Run(args):
  """Handles a single invocation of the predict command.

  Args:
    args: an argparse namespace containing every argument passed to this
      command invocation.

  Returns:
    A json object that contains predictions.
  """
  instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  with endpoint_util.MlEndpointOverrides(region=args.region):
    target_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)
    # The runtime-version lookup only happens when no signature name was
    # supplied (short-circuit `and`).
    warn_about_signature = (
        args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version))
    if warn_about_signature:
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')
    response = predict.Predict(
        target_ref, instances, signature_name=args.signature_name)

  if not args.IsSpecified('format'):
    # No explicit --format: derive one from the returned predictions.
    args.format = predict_utilities.GetDefaultFormat(
        response.get('predictions'))
  return response
def testPredictTextInstancesWithJSON(self):
  """Text instances that look like JSON are still sent as JSON strings."""
  self.mock_http.request.return_value = [self.http_response, self.http_body]
  text_instances = [
      '{"images": [0, 1], "key": 3}',
      '{"images": [0.3, 0.2], "key": 2}',
      '{"images": [0.2, 0.1], "key": 1}',
  ]

  result = predict.Predict(self.version_ref, instances=text_instances)

  # Each instance is serialized as a JSON string, so its quotes are escaped.
  expected_body = (
      '{"instances": ["{\\"images\\": [0, 1], \\"key\\": 3}", '
      '"{\\"images\\": [0.3, 0.2], \\"key\\": 2}", '
      '"{\\"images\\": [0.2, 0.1], \\"key\\": 1}"]}')
  expected_uri = (self._BASE_URL + 'projects/fake-project/'
                  'models/my_model/versions/v1:predict')
  self.mock_http.request.assert_called_once_with(
      uri=expected_uri,
      method='POST',
      body=expected_body,
      headers={'Content-Type': 'application/json'})
  self.assertEqual(json.loads(self.http_body), result)
def testPredictMultipleJsonInstances(self):
  """Several dict instances are sent together in one request body."""
  self.mock_http.request.return_value = [self.http_response, self.http_body]
  dict_instances = [
      {'images': [0, 1], 'key': 3},
      {'images': [2, 3], 'key': 2},
      {'images': [3, 1], 'key': 1},
  ]

  result = predict.Predict(self.version_ref, dict_instances)

  expected_body = ('{"instances": [{"images": [0, 1], "key": 3}, '
                   '{"images": [2, 3], "key": 2}, '
                   '{"images": [3, 1], "key": 1}]}')
  expected_uri = (self._BASE_URL + 'projects/fake-project/'
                  'models/my_model/versions/v1:predict')
  self.mock_http.request.assert_called_once_with(
      uri=expected_uri,
      method='POST',
      body=expected_body,
      headers={'Content-Type': 'application/json'})
  self.assertEqual(json.loads(self.http_body), result)
def testPredictFailedRequest(self):
  """A 502 response makes Predict raise core_exceptions.Error.

  The raised error's message is expected to include the response body.
  """
  failed_response = {'status': '502'}
  failed_response_body = 'Error 502'
  self.mock_http.request.return_value = [failed_response, failed_response_body]
  # assertRaisesRegex treats the message as a regex. Escape the '.' so it
  # matches the literal period instead of any character.
  with self.assertRaisesRegex(
      core_exceptions.Error, r'HTTP request failed\. Response: Error 502'):
    predict.Predict(self.version_ref, self.test_instances)
def testPredictInvalidResponse(self):
  """A response body that is not valid JSON surfaces as an error."""
  non_json_body = 'abcd'  # deliberately not a JSON document
  self.mock_http.request.return_value = [self.http_response, non_json_body]
  expected_message = (
      'No JSON object could be decoded from the HTTP response body: abcd')
  with self.assertRaisesRegex(core_exceptions.Error, expected_message):
    predict.Predict(self.version_ref, self.test_instances)
def testPredictJsonInstances(self):
  """The fixture instances are POSTed as the expected JSON body."""
  self.mock_http.request.return_value = [self.http_response, self.http_body]

  result = predict.Predict(self.version_ref, self.test_instances)

  expected_uri = (self._BASE_URL + 'projects/fake-project/'
                  'models/my_model/versions/v1:predict')
  json_headers = {'Content-Type': 'application/json'}
  self.mock_http.request.assert_called_once_with(
      uri=expected_uri,
      method='POST',
      body=self.expected_body,
      headers=json_headers)
  self.assertEqual(json.loads(self.http_body), result)
def testPredictionSignatureName(self):
  """An explicit signature_name is added to the request body."""
  self.mock_http.request.return_value = [self.http_response, self.http_body]

  result = predict.Predict(
      self.version_ref, ['2, 3'], signature_name='my-custom-signature')

  expected_body = (
      '{"instances": ["2, 3"], "signature_name": "my-custom-signature"}')
  expected_uri = (self._BASE_URL + 'projects/fake-project/'
                  'models/my_model/versions/v1:predict')
  self.mock_http.request.assert_called_once_with(
      uri=expected_uri,
      method='POST',
      body=expected_body,
      headers={'Content-Type': 'application/json'})
  self.assertEqual(json.loads(self.http_body), result)
def testPredictNonUtf8Instances(self):
  """Byte instances that are not valid UTF-8 are rejected with an error."""
  undecodable_instances = [b'\x89PNG']
  with self.assertRaisesRegex(core_exceptions.Error,
                              'Instances cannot be JSON encoded'):
    predict.Predict(self.version_ref, undecodable_instances)