Example #1
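Builds the model_spec_pb2.InferenceSpecType proto that points TFX bulk inference at a model hosted on Cloud AI Platform, validating the required serving arguments along the way.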
# Method of TFX's Cloud AI Platform bulk inferrer executor. Requires
# (module paths as in the TFX source tree):
#   from typing import Any, Dict, Text
#   import tensorflow as tf
#   from absl import logging
#   from tfx_bsl.public.proto import model_spec_pb2
def _get_inference_spec(
    self, model_path: Text, model_version: Text,
    ai_platform_serving_args: Dict[Text, Any]
) -> model_spec_pb2.InferenceSpecType:
  if 'project_id' not in ai_platform_serving_args:
    raise ValueError(
        '\'project_id\' is missing in \'ai_platform_serving_args\'')
  project_id = ai_platform_serving_args['project_id']
  if 'model_name' not in ai_platform_serving_args:
    raise ValueError(
        '\'model_name\' is missing in \'ai_platform_serving_args\'')
  model_name = ai_platform_serving_args['model_name']
  ai_platform_prediction_model_spec = (
      model_spec_pb2.AIPlatformPredictionModelSpec(
          project_id=project_id,
          model_name=model_name,
          version_name=model_version))
  model_signature = self._get_model_signature(model_path)
  # A signature with a single string input accepts serialized tf.Example
  # records directly, so enable the serialization config for it.
  if (len(model_signature.inputs) == 1 and
      list(model_signature.inputs.values())[0].dtype ==
      tf.string.as_datatype_enum):
    ai_platform_prediction_model_spec.use_serialization_config = True
  logging.info(
      'Using hosted model on Cloud AI platform, model_name: %s, '
      'model_version: %s.', model_name, model_version)
  result = model_spec_pb2.InferenceSpecType()
  result.ai_platform_prediction_model_spec.CopyFrom(
      ai_platform_prediction_model_spec)
  return result
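For reference, a minimal sketch of the proto this method produces; the project, model, and version identifiers below are hypothetical, and the import path follows tfx_bsl:

from tfx_bsl.public.proto import model_spec_pb2

spec = model_spec_pb2.InferenceSpecType()
spec.ai_platform_prediction_model_spec.project_id = 'my-gcp-project'  # hypothetical
spec.ai_platform_prediction_model_spec.model_name = 'my_model'        # hypothetical
spec.ai_platform_prediction_model_spec.version_name = 'v1'            # hypothetical
# Only set when the model signature takes a single serialized-string input.
spec.ai_platform_prediction_model_spec.use_serialization_config = True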
Example #2
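A unit test for the executor's Do() method covering the case where the AI Platform model endpoint already exists, so its creation is skipped during deployment and it is preserved during cleanup.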
  # Test method of the executor's test class; the three parameters are
  # injected by @mock.patch decorators (applied bottom-up). Requires
  # (module paths as in the TFX source tree):
  #   from unittest import mock
  #   from tfx.extensions.google_cloud_ai_platform.bulk_inferrer import executor
  #   from tfx.proto import bulk_inferrer_pb2
  #   from tfx.utils import json_utils, path_utils, proto_utils, telemetry_utils
  #   from tfx_bsl.public.proto import model_spec_pb2
  def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference,
                                 _):
    input_dict = {
        'examples': [self._examples],
        'model': [self._model],
        'model_blessing': [self._model_blessing],
    }
    output_dict = {
        'inference_result': [self._inference_result],
    }
    ai_platform_serving_args = {
        'model_name': 'model_name',
        'project_id': 'project_id'
    }
    # Create exec properties.
    exec_properties = {
        'data_spec':
            proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
        'custom_config':
            json_utils.dumps(
                {executor.SERVING_ARGS_KEY: ai_platform_serving_args}),
    }
    mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
    # Returning False simulates a model endpoint that already exists, so the
    # executor should skip creating it.
    mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False

    # Run executor.
    bulk_inferrer = executor.Executor(self._context)
    bulk_inferrer.Do(input_dict, output_dict, exec_properties)

    # Expected inference spec: the executor should target the hosted model
    # version with serialization config enabled.
    ai_platform_prediction_model_spec = (
        model_spec_pb2.AIPlatformPredictionModelSpec(
            project_id='project_id',
            model_name='model_name',
            version_name=self._model_version))
    ai_platform_prediction_model_spec.use_serialization_config = True
    inference_endpoint = model_spec_pb2.InferenceSpecType()
    inference_endpoint.ai_platform_prediction_model_spec.CopyFrom(
        ai_platform_prediction_model_spec)
    mock_run_model_inference.assert_called_once_with(mock.ANY, mock.ANY,
                                                     mock.ANY, mock.ANY,
                                                     mock.ANY,
                                                     inference_endpoint)
    executor_class_path = '%s.%s' % (bulk_inferrer.__class__.__module__,
                                     bulk_inferrer.__class__.__name__)
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()
    # Deployment must reuse the existing model endpoint rather than recreate
    # it, and must not mark the new version as default.
    mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
        serving_path=path_utils.serving_model_path(self._model.uri),
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        labels=job_labels,
        api=mock.ANY,
        skip_model_endpoint_creation=True,
        set_default=False)
    # Cleanup deletes only the model version, keeping the endpoint.
    mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        api=mock.ANY,
        delete_model_endpoint=False)
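The test uses mock.ANY for arguments it does not care about (the API client handle and the generated version name), which keeps the assertions focused on the contract under test: because the model endpoint already existed, deployment must pass skip_model_endpoint_creation=True and set_default=False, and cleanup must delete only the version, not the endpoint (delete_model_endpoint=False).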