def _infer(self, request):
  """Runs inference on the pending examples and returns the results as JSON.

  The examples sent to the model are those whose indices are in
  `self.updated_example_indices` (presumably the examples edited since the
  last successful inference). That set is cleared only after inference
  succeeds.

  Args:
    request: A GET request whose args may contain 'inference_address',
      'model_name', 'model_type', 'model_version', 'model_signature',
      'use_predict' ('true' enables it), 'predict_input_tensor',
      'predict_output_tensor' and, optionally, 'label_vocab_path'.

  Returns:
    An 'application/json' HTTP response. On success the body has two keys,
    each holding a JSON-encoded string: 'inferences' (an object with the
    sorted 'indices' that were inferred and the wrapped 'results') and
    'vocab' (the label vocabulary, one entry per line of the vocab file, or
    an empty list). A non-GET request gets a 405. An
    `InvalidUserInputError` or `AbortionError` gets a 400. Both error
    responses carry an 'error' message.
  """
  # Reject non-GET requests before doing any work. In particular, do not
  # read the request-supplied vocab file for a request that will be refused.
  if request.method != 'GET':
    tf.logging.error('%s requests are forbidden.', request.method)
    return http_util.Respond(request, {'error': 'invalid non-GET request'},
                             'application/json', code=405)

  # The label vocab is optional and best-effort. A missing file is logged
  # and treated as an empty vocab instead of failing the request.
  # NOTE(review): the path is user-supplied and the file contents are echoed
  # back in the response; confirm only trusted clients can reach this.
  label_vocab = []
  vocab_path = request.args.get('label_vocab_path')
  if vocab_path:
    try:
      with tf.gfile.GFile(vocab_path, 'r') as f:
        label_vocab = [line.rstrip('\n') for line in f]
    except tf.errors.NotFoundError as err:
      tf.logging.error('error reading vocab file: %s', err)

  try:
    serving_bundle = inference_utils.ServingBundle(
        request.args.get('inference_address'),
        request.args.get('model_name'), request.args.get('model_type'),
        request.args.get('model_version'),
        request.args.get('model_signature'),
        request.args.get('use_predict') == 'true',
        request.args.get('predict_input_tensor'),
        request.args.get('predict_output_tensor'))
    indices_to_infer = sorted(self.updated_example_indices)
    examples_to_infer = [self.examples[index] for index in indices_to_infer]

    # Get the inference results proto, combine it with the indices of the
    # inferred examples, and respond with this data as JSON.
    inference_result_proto = platform_utils.call_servo(
        examples_to_infer, serving_bundle)
    new_inferences = inference_utils.wrap_inference_results(
        inference_result_proto)
    infer_json = json_format.MessageToJson(
        new_inferences, including_default_value_fields=True)
    infer_obj = json.loads(infer_json)
    resp = {'indices': indices_to_infer, 'results': infer_obj}
    # Clear the pending set only after inference has succeeded. If any step
    # above raises, those indices stay pending for the next call.
    self.updated_example_indices = set()
    return http_util.Respond(request, {'inferences': json.dumps(resp),
                                       'vocab': json.dumps(label_vocab)},
                             'application/json')
  except common_utils.InvalidUserInputError as e:
    return http_util.Respond(request, {'error': e.message},
                             'application/json', code=400)
  except AbortionError as e:
    return http_util.Respond(request, {'error': e.details},
                             'application/json', code=400)
Ejemplo n.º 2
0
  def test_wrap_inference_results_regression(self):
    """Test that wrapping a RegressionResponse keeps both regressions."""
    response = regression_pb2.RegressionResponse()
    for value in (0.45, 0.55):
      response.result.regressions.add().value = value

    wrapped = inference_utils.wrap_inference_results(response)
    regressions = wrapped.regression_result.regressions
    self.assertEqual(2, len(regressions))
  def _infer(self, request):
    """Runs inference on the pending examples and returns the results as JSON.

    The examples sent to the model are those whose indices are in
    `self.updated_example_indices` (presumably the examples edited since the
    last successful inference). That set is cleared only after inference
    succeeds.

    Args:
      request: A GET request whose args may contain 'inference_address',
        'model_name', 'model_type', 'model_version', 'model_signature' and,
        optionally, 'label_vocab_path'.

    Returns:
      An 'application/json' HTTP response. On success the body has two keys,
      each holding a JSON-encoded string: 'inferences' (an object with the
      sorted 'indices' that were inferred and the wrapped 'results') and
      'vocab' (the label vocabulary, one entry per line of the vocab file, or
      an empty list). A non-GET request gets a 405. An
      `InvalidUserInputError` or `AbortionError` gets a 400. Both error
      responses carry an 'error' message.
    """
    # Reject non-GET requests before doing any work. In particular, do not
    # read the request-supplied vocab file for a request that will be refused.
    if request.method != 'GET':
      tf.logging.error('%s requests are forbidden.', request.method)
      return http_util.Respond(request, {'error': 'invalid non-GET request'},
                               'application/json', code=405)

    # The label vocab is optional and best-effort. A missing file is logged
    # and treated as an empty vocab instead of failing the request.
    # NOTE(review): the path is user-supplied and the file contents are echoed
    # back in the response; confirm only trusted clients can reach this.
    label_vocab = []
    vocab_path = request.args.get('label_vocab_path')
    if vocab_path:
      try:
        with tf.gfile.GFile(vocab_path, 'r') as f:
          label_vocab = [line.rstrip('\n') for line in f]
      except tf.errors.NotFoundError as err:
        tf.logging.error('error reading vocab file: %s', err)

    try:
      serving_bundle = inference_utils.ServingBundle(
          request.args.get('inference_address'),
          request.args.get('model_name'), request.args.get('model_type'),
          request.args.get('model_version'),
          request.args.get('model_signature'))
      indices_to_infer = sorted(self.updated_example_indices)
      examples_to_infer = [self.examples[index] for index in indices_to_infer]

      # Get the inference results proto, combine it with the indices of the
      # inferred examples, and respond with this data as JSON.
      inference_result_proto = platform_utils.call_servo(
          examples_to_infer, serving_bundle)
      new_inferences = inference_utils.wrap_inference_results(
          inference_result_proto)
      infer_json = json_format.MessageToJson(
          new_inferences, including_default_value_fields=True)
      infer_obj = json.loads(infer_json)
      resp = {'indices': indices_to_infer, 'results': infer_obj}
      # Clear the pending set only after inference has succeeded. If any step
      # above raises, those indices stay pending for the next call.
      self.updated_example_indices = set()
      return http_util.Respond(request, {'inferences': json.dumps(resp),
                                         'vocab': json.dumps(label_vocab)},
                               'application/json')
    except common_utils.InvalidUserInputError as e:
      return http_util.Respond(request, {'error': e.message},
                               'application/json', code=400)
    except AbortionError as e:
      return http_util.Respond(request, {'error': e.details},
                               'application/json', code=400)
Ejemplo n.º 4
0
  def test_wrap_inference_results_classification(self):
    """Test that wrapping preserves the classification and its classes."""
    inference_result_proto = classification_pb2.ClassificationResponse()
    classification = inference_result_proto.result.classifications.add()
    for label, score in (('class_b', 0.3), ('class_a', 0.7)):
      inference_class = classification.classes.add()
      inference_class.label = label
      inference_class.score = score

    wrapped = inference_utils.wrap_inference_results(inference_result_proto)
    wrapped_classifications = wrapped.classification_result.classifications
    self.assertEqual(1, len(wrapped_classifications))
    self.assertEqual(2, len(wrapped_classifications[0].classes))