def _get_inference_spec(
    self, model_path: Text, model_version: Text,
    ai_platform_serving_args: Dict[Text, Any]
) -> model_spec_pb2.InferenceSpecType:
  """Builds an InferenceSpecType pointing at a Cloud AI Platform model.

  Args:
    model_path: Path to the exported model; used only to inspect its
      serving signature.
    model_version: Version name of the model hosted on AI Platform.
    ai_platform_serving_args: Arguments identifying the hosted model;
      must contain the keys 'project_id' and 'model_name'.

  Returns:
    A `model_spec_pb2.InferenceSpecType` with its
    `ai_platform_prediction_model_spec` field populated.

  Raises:
    ValueError: If 'project_id' or 'model_name' is missing from
      `ai_platform_serving_args`.
  """
  if 'project_id' not in ai_platform_serving_args:
    raise ValueError(
        '\'project_id\' is missing in \'ai_platform_serving_args\'')
  project_id = ai_platform_serving_args['project_id']
  if 'model_name' not in ai_platform_serving_args:
    raise ValueError(
        '\'model_name\' is missing in \'ai_platform_serving_args\'')
  model_name = ai_platform_serving_args['model_name']
  ai_platform_prediction_model_spec = (
      model_spec_pb2.AIPlatformPredictionModelSpec(
          project_id=project_id,
          model_name=model_name,
          version_name=model_version))
  model_signature = self._get_model_signature(model_path)
  # A single string-typed input indicates the model consumes serialized
  # examples, so enable the service-side serialization config.
  if (len(model_signature.inputs) == 1 and list(
      model_signature.inputs.values())[0].dtype == tf.string.as_datatype_enum
     ):
    ai_platform_prediction_model_spec.use_serialization_config = True
  # Bug fix: the two adjacent string literals previously concatenated
  # without a separating space, logging "model_name: %s,model_version: %s.".
  logging.info(
      'Using hosted model on Cloud AI platform, model_name: %s, '
      'model_version: %s.', model_name, model_version)
  result = model_spec_pb2.InferenceSpecType()
  result.ai_platform_prediction_model_spec.CopyFrom(
      ai_platform_prediction_model_spec)
  return result
def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference, _):
  """Verifies Do() when the AIP model already exists and creation is skipped."""
  input_dict = {
      'examples': [self._examples],
      'model': [self._model],
      'model_blessing': [self._model_blessing],
  }
  output_dict = {'inference_result': [self._inference_result]}
  serving_args = {'model_name': 'model_name', 'project_id': 'project_id'}
  # Create exe properties.
  exec_properties = {
      'data_spec':
          proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
      'custom_config':
          json_utils.dumps({executor.SERVING_ARGS_KEY: serving_args}),
  }
  mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
  # Simulate that the model already exists, so creation is skipped.
  mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False

  # Run executor.
  bulk_inferrer = executor.Executor(self._context)
  bulk_inferrer.Do(input_dict, output_dict, exec_properties)

  # Inference must target the hosted model with serialization config on.
  expected_model_spec = model_spec_pb2.AIPlatformPredictionModelSpec(
      project_id='project_id',
      model_name='model_name',
      version_name=self._model_version)
  expected_model_spec.use_serialization_config = True
  expected_endpoint = model_spec_pb2.InferenceSpecType()
  expected_endpoint.ai_platform_prediction_model_spec.CopyFrom(
      expected_model_spec)
  mock_run_model_inference.assert_called_once_with(mock.ANY, mock.ANY,
                                                   mock.ANY, mock.ANY,
                                                   mock.ANY, expected_endpoint)

  # Deployment must carry the executor's telemetry labels.
  executor_class_path = '%s.%s' % (bulk_inferrer.__class__.__module__,
                                   bulk_inferrer.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
      serving_path=path_utils.serving_model_path(self._model.uri),
      model_version_name=mock.ANY,
      ai_platform_serving_args=serving_args,
      labels=job_labels,
      api=mock.ANY,
      skip_model_endpoint_creation=True,
      set_default=False)
  # The temporary model version must be cleaned up afterwards.
  mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
      model_version_name=mock.ANY,
      ai_platform_serving_args=serving_args,
      api=mock.ANY,
      delete_model_endpoint=False)