Example #1
  def _infer_mutants_impl(self, feature_name, example_index, inference_addresses,
      model_names, model_type, model_versions, model_signatures, use_predict,
      predict_input_tensor, predict_output_tensor, x_min, x_max,
      feature_index_pattern, custom_predict_fn):
    """Helper for generating PD plots for a feature."""
    examples = (self.examples if example_index == -1
                else [self.examples[example_index]])
    serving_bundles = []
    # Build one ServingBundle per model; the address/name/version/signature
    # lists are indexed in parallel. ('range' replaces the Python 2-only 'xrange'.)
    for model_num in range(len(inference_addresses)):
      serving_bundles.append(inference_utils.ServingBundle(
          inference_addresses[model_num],
          model_names[model_num],
          model_type,
          model_versions[model_num],
          model_signatures[model_num],
          use_predict,
          predict_input_tensor,
          predict_output_tensor,
          custom_predict_fn=custom_predict_fn))

    viz_params = inference_utils.VizParams(
        x_min, x_max,
        self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS,
        feature_index_pattern)
    return inference_utils.mutant_charts_for_feature(
        examples, feature_name, serving_bundles, viz_params)
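
The helper above builds one inference_utils.ServingBundle per model and delegates the actual charting to inference_utils.mutant_charts_for_feature. A minimal sketch of a call, assuming a plugin instance `plugin` with examples already loaded; every literal value below is an assumption for illustration, not a value from the source:

charts = plugin._infer_mutants_impl(
    feature_name='age',                      # feature whose values get mutated
    example_index=-1,                        # -1 means "use all loaded examples"
    inference_addresses=['localhost:8500'],  # one entry per model
    model_names=['my_model'],
    model_type='classification',
    model_versions=[''],
    model_signatures=[''],
    use_predict=False,
    predict_input_tensor='',
    predict_output_tensor='',
    x_min=None,                              # let VizParams infer the x-range
    x_max=None,
    feature_index_pattern=None,
    custom_predict_fn=None)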
Example #2
    def _infer(self, request):
        """Returns JSON for the `vz-line-chart`s for a feature.

    Args:
      request: A request that should contain 'inference_address', 'model_name',
        'model_type, 'model_version', 'model_signature' and 'label_vocab_path'.

    Returns:
      A list of JSON objects, one for each chart.
    """
        label_vocab = inference_utils.get_label_vocab(
            request.args.get('label_vocab_path'))

        try:
            if request.method != 'GET':
                logger.error('%s requests are forbidden.', request.method)
                return http_util.Respond(request,
                                         {'error': 'invalid non-GET request'},
                                         'application/json',
                                         code=405)

            (inference_addresses, model_names, model_versions,
             model_signatures) = self._parse_request_arguments(request)

            indices_to_infer = sorted(self.updated_example_indices)
            examples_to_infer = [
                self.examples[index] for index in indices_to_infer
            ]
            infer_objs = []
            # Query each model in turn and collect its predictions.
            for model_num in range(len(inference_addresses)):
                serving_bundle = inference_utils.ServingBundle(
                    inference_addresses[model_num],
                    model_names[model_num],
                    request.args.get('model_type'),
                    model_versions[model_num],
                    model_signatures[model_num],
                    request.args.get('use_predict') == 'true',
                    request.args.get('predict_input_tensor'),
                    request.args.get('predict_output_tensor'),
                    custom_predict_fn=self.custom_predict_fn)
                (predictions,
                 _) = inference_utils.run_inference_for_inference_results(
                     examples_to_infer, serving_bundle)
                infer_objs.append(predictions)

            resp = {'indices': indices_to_infer, 'results': infer_objs}
            self.updated_example_indices = set()
            return http_util.Respond(request, {
                'inferences': json.dumps(resp),
                'vocab': json.dumps(label_vocab)
            }, 'application/json')
        except common_utils.InvalidUserInputError as e:
            # Exceptions have no '.message' attribute in Python 3, so
            # format the exception itself.
            return http_util.Respond(request, {'error': str(e)},
                                     'application/json',
                                     code=400)
        except AbortionError as e:
            return http_util.Respond(request, {'error': e.details},
                                     'application/json',
                                     code=400)
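
This handler runs inference only for the example indices marked as updated, then returns a JSON body whose 'inferences' and 'vocab' fields are themselves JSON-encoded strings. A hypothetical client-side decode, assuming the handler is mounted at a route such as '/infer' (the URL and query values below are assumptions):

import json
import urllib.request

url = ('http://localhost:6006/data/plugin/whatif/infer'
       '?inference_address=localhost:8500'
       '&model_name=my_model&model_type=classification')
with urllib.request.urlopen(url) as response:
    body = json.load(response)

inferences = json.loads(body['inferences'])  # {'indices': [...], 'results': [...]}
label_vocab = json.loads(body['vocab'])      # e.g. ['cat', 'dog']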
Example #3
  def _infer(self, request):
    """Returns JSON for the `vz-line-chart`s for a feature.

    Args:
      request: A request that should contain 'inference_address', 'model_name',
        'model_type, 'model_version', 'model_signature' and 'label_vocab_path'.

    Returns:
      A list of JSON objects, one for each chart.
    """
    start_example = (int(request.args.get('start_example'))
                     if request.args.get('start_example') else 0)
    # Inference runs only on the first page request; later pages are
    # served from the results cached on the instance.
    if not start_example:
      label_vocab = inference_utils.get_label_vocab(
        request.args.get('label_vocab_path'))
      try:
        if request.method != 'GET':
          logger.error('%s requests are forbidden.', request.method)
          return http_util.Respond(request, 'invalid non-GET request',
                                   'application/json', code=405)

        (inference_addresses, model_names, model_versions,
            model_signatures) = self._parse_request_arguments(request)

        self.indices_to_infer = sorted(self.updated_example_indices)
        examples_to_infer = [
            self.examples[index] for index in self.indices_to_infer]
        self.infer_objs = []
        for model_num in range(len(inference_addresses)):
          serving_bundle = inference_utils.ServingBundle(
              inference_addresses[model_num],
              model_names[model_num],
              request.args.get('model_type'),
              model_versions[model_num],
              model_signatures[model_num],
              request.args.get('use_predict') == 'true',
              request.args.get('predict_input_tensor'),
              request.args.get('predict_output_tensor'),
              custom_predict_fn=self.custom_predict_fn)
          (predictions, _) = inference_utils.run_inference_for_inference_results(
              examples_to_infer, serving_bundle)
          self.infer_objs.append(predictions)
        self.updated_example_indices = set()
      except AbortionError as e:
        logger.error(str(e))
        return http_util.Respond(request, e.details,
                                 'application/json', code=400)
      except Exception as e:
        logger.error(str(e))
        return http_util.Respond(request, str(e),
                                 'application/json', code=400)

    # Slice the cached results from start_example up to
    # start_example + MAX_EXAMPLES_TO_SEND, and report where the next page
    # starts (or -1 when everything has been sent).
    end_example = start_example + MAX_EXAMPLES_TO_SEND

    def get_inferences_resp():
      # Deep-copy the cached results so the in-place slicing below cannot
      # mutate the full result set kept on the instance.
      sliced_infer_objs = [
          copy.deepcopy(infer_obj) for infer_obj in self.infer_objs]
      if request.args.get('model_type') == 'classification':
        for obj in sliced_infer_objs:
          obj['classificationResult']['classifications'][:] = obj[
            'classificationResult']['classifications'][
              start_example:end_example]
      else:
        for obj in sliced_infer_objs:
          obj['regressionResult']['regressions'][:] = obj['regressionResult'][
            'regressions'][start_example:end_example]
      return {'indices': self.indices_to_infer[start_example:end_example],
              'results': sliced_infer_objs}

    try:
      inferences_resp = get_inferences_resp()
      resp = {'inferences': json.dumps(inferences_resp)}
      if end_example >= len(self.examples):
        end_example = -1
      if start_example == 0:
        resp['vocab'] = json.dumps(label_vocab)
      resp['next'] = end_example
      return http_util.Respond(request, resp, 'application/json')
    except Exception as e:
      logger.error(str(e))
      return http_util.Respond(request, str(e),
                               'application/json', code=400)
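
Compared with Example #2, this version caches the inference results on the instance and pages them out MAX_EXAMPLES_TO_SEND at a time, using the 'next' field to tell the client where to resume. A hypothetical paging loop against such a handler; the route, host, and query values are assumptions for illustration:

import json
import urllib.request

base = ('http://localhost:6006/data/plugin/whatif/infer'
        '?inference_address=localhost:8500&model_name=my_model'
        '&model_type=classification')
start = 0
results = []
while start != -1:
    with urllib.request.urlopen(base + '&start_example=%d' % start) as response:
        body = json.load(response)
    page = json.loads(body['inferences'])
    results.extend(page['results'])
    start = body['next']   # -1 once the last page has been sent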