def _infer(self, request):
  """Returns JSON inference results for the updated examples.

  Args:
    request: A request that should contain 'inference_address', 'model_name',
      'model_type', 'model_version', 'model_signature' and 'label_vocab_path'.
      It may also contain 'use_predict', 'predict_input_tensor' and
      'predict_output_tensor' for use with the Predict API.

  Returns:
    A JSON response containing the inference results for the updated
    examples, along with the label vocabulary.
  """
  vocab_path = request.args.get('label_vocab_path')
  if vocab_path:
    try:
      with tf.gfile.GFile(vocab_path, 'r') as f:
        label_vocab = [line.rstrip('\n') for line in f]
    except tf.errors.NotFoundError as err:
      tf.logging.error('error reading vocab file: %s', err)
      label_vocab = []
  else:
    label_vocab = []

  try:
    if request.method != 'GET':
      tf.logging.error('%s requests are forbidden.', request.method)
      return http_util.Respond(request,
                               {'error': 'invalid non-GET request'},
                               'application/json', code=405)

    serving_bundle = inference_utils.ServingBundle(
        request.args.get('inference_address'),
        request.args.get('model_name'),
        request.args.get('model_type'),
        request.args.get('model_version'),
        request.args.get('model_signature'),
        request.args.get('use_predict') == 'true',
        request.args.get('predict_input_tensor'),
        request.args.get('predict_output_tensor'))
    indices_to_infer = sorted(self.updated_example_indices)
    examples_to_infer = [self.examples[index] for index in indices_to_infer]

    # Get the inference results proto, combine it with the indices of the
    # inferred examples, and respond with this data as JSON.
    inference_result_proto = platform_utils.call_servo(
        examples_to_infer, serving_bundle)
    new_inferences = inference_utils.wrap_inference_results(
        inference_result_proto)
    infer_json = json_format.MessageToJson(
        new_inferences, including_default_value_fields=True)
    infer_obj = json.loads(infer_json)
    resp = {'indices': indices_to_infer, 'results': infer_obj}
    self.updated_example_indices = set()
    return http_util.Respond(request,
                             {'inferences': json.dumps(resp),
                              'vocab': json.dumps(label_vocab)},
                             'application/json')
  except common_utils.InvalidUserInputError as e:
    return http_util.Respond(request, {'error': e.message},
                             'application/json', code=400)
  except AbortionError as e:
    return http_util.Respond(request, {'error': e.details},
                             'application/json', code=400)
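# --- Hedged usage sketch (an illustration, not part of the plugin) ----------
# The handler above reads its parameters from the query string of a GET
# request and double-encodes its payload: the 'inferences' and 'vocab'
# values in the JSON body are themselves JSON strings, so a client must
# decode twice. The route and host below are assumptions for illustration;
# the actual URL comes from the plugin's route registration, not shown here.
import json
import requests

resp = requests.get(
    'http://localhost:6006/data/plugin/whatif/infer',  # hypothetical route
    params={
        'inference_address': 'localhost:8500',  # model server address
        'model_name': 'my_model',
        'model_type': 'classification',
        'model_version': '',        # left unset for this sketch
        'model_signature': '',
        'use_predict': 'false',
        'label_vocab_path': '',
    })
body = resp.json()
inferences = json.loads(body['inferences'])  # second decode is required
print(inferences['indices'])                 # indices of inferred examples
print(inferences['results'])                 # wrapped results as a dict
vocab = json.loads(body['vocab'])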
def test_wrap_inference_results_regression(self):
  """Test wrapping a regression result."""
  inference_result_proto = regression_pb2.RegressionResponse()
  regression = inference_result_proto.result.regressions.add()
  regression.value = 0.45
  regression = inference_result_proto.result.regressions.add()
  regression.value = 0.55
  wrapped = inference_utils.wrap_inference_results(inference_result_proto)
  self.assertEqual(2, len(wrapped.regression_result.regressions))
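# --- Hedged companion test (an addition, not from the original suite) -------
# The test above only checks the number of wrapped regressions. Assuming
# wrap_inference_results also copies each regression's value through in
# order (an assumption not verified by the snippet above), a companion
# test could pin the values down as well:
def test_wrap_inference_results_regression_values(self):
  """Test that wrapped regression values match the originals."""
  inference_result_proto = regression_pb2.RegressionResponse()
  inference_result_proto.result.regressions.add().value = 0.45
  inference_result_proto.result.regressions.add().value = 0.55
  wrapped = inference_utils.wrap_inference_results(inference_result_proto)
  self.assertAlmostEqual(0.45, wrapped.regression_result.regressions[0].value)
  self.assertAlmostEqual(0.55, wrapped.regression_result.regressions[1].value)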
def test_wrap_inference_results_classification(self):
  """Test wrapping a classification result."""
  inference_result_proto = classification_pb2.ClassificationResponse()
  classification = inference_result_proto.result.classifications.add()
  inference_class = classification.classes.add()
  inference_class.label = 'class_b'
  inference_class.score = 0.3
  inference_class = classification.classes.add()
  inference_class.label = 'class_a'
  inference_class.score = 0.7
  wrapped = inference_utils.wrap_inference_results(inference_result_proto)
  self.assertEqual(1, len(wrapped.classification_result.classifications))
  self.assertEqual(
      2, len(wrapped.classification_result.classifications[0].classes))
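# --- Hedged companion test (an addition, not from the original suite) -------
# As with the regression case, the assertions above only check counts.
# Assuming the wrapped classifications expose the same 'label' and 'score'
# fields in the same order as the response proto (an assumption), the
# contents can be checked too:
def test_wrap_inference_results_classification_classes(self):
  """Test that wrapped class labels and scores match the originals."""
  inference_result_proto = classification_pb2.ClassificationResponse()
  classification = inference_result_proto.result.classifications.add()
  classification.classes.add(label='class_b', score=0.3)
  classification.classes.add(label='class_a', score=0.7)
  wrapped = inference_utils.wrap_inference_results(inference_result_proto)
  classes = wrapped.classification_result.classifications[0].classes
  self.assertEqual('class_b', classes[0].label)
  self.assertEqual('class_a', classes[1].label)
  self.assertAlmostEqual(0.3, classes[0].score)
  self.assertAlmostEqual(0.7, classes[1].score)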