예제 #1
0
  def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference,
                                 _):
    """Runs Do() when the runner reports that no new model was created.

    `create_model_for_aip_prediction_if_not_exist` is mocked to return False.
    The test then checks three things:
    - inference is run against the expected AI Platform endpoint;
    - the version is deployed with `skip_model_endpoint_creation=True`;
    - `delete_model_from_aip_if_exists` is called with
      `delete_model_endpoint=False`.
    """
    serving_args = {
        'model_name': 'model_name',
        'project_id': 'project_id'
    }
    inputs = dict(
        examples=[self._examples],
        model=[self._model],
        model_blessing=[self._model_blessing],
    )
    outputs = dict(inference_result=[self._inference_result])
    # Execution properties: a default DataSpec as JSON, plus the serving args
    # nested under the executor's serving-args key inside custom_config.
    properties = dict(
        data_spec=proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
        custom_config=json_utils.dumps(
            {executor.SERVING_ARGS_KEY: serving_args}),
    )

    # The runner reports that no new model was created.
    mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False
    mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')

    inferrer = executor.Executor(self._context)
    inferrer.Do(inputs, outputs, properties)

    # Inference is expected to target `model_name` at `self._model_version`.
    expected_spec = model_spec_pb2.AIPlatformPredictionModelSpec(
        project_id='project_id',
        model_name='model_name',
        version_name=self._model_version,
        use_serialization_config=True)
    expected_endpoint = model_spec_pb2.InferenceSpecType()
    expected_endpoint.ai_platform_prediction_model_spec.CopyFrom(expected_spec)
    # Only the sixth positional argument (the inference spec) is pinned.
    mock_run_model_inference.assert_called_once_with(
        *([mock.ANY] * 5 + [expected_endpoint]))

    # The expected labels are those produced while this executor's fully
    # qualified class path is in scope as the TFX executor label.
    inferrer_cls = type(inferrer)
    class_path = '{}.{}'.format(inferrer_cls.__module__, inferrer_cls.__name__)
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: class_path}):
      expected_labels = telemetry_utils.make_labels_dict()

    mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
        serving_path=path_utils.serving_model_path(self._model.uri),
        model_version_name=mock.ANY,
        ai_platform_serving_args=serving_args,
        labels=expected_labels,
        api=mock.ANY,
        skip_model_endpoint_creation=True,
        set_default=False)
    mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
        model_version_name=mock.ANY,
        ai_platform_serving_args=serving_args,
        api=mock.ANY,
        delete_model_endpoint=False)
예제 #2
0
    def testDoFailedModelDeployment(self, mock_runner,
                                    mock_run_model_inference, _):
        """Checks the cleanup calls made when model deployment raises.

        The mocked runner reports that a new model was created (True from
        `create_model_for_aip_prediction_if_not_exist`), and its
        `deploy_model_for_aip_prediction` raises. After Do() raises, the
        test expects each of the two deletion helpers to have been called
        exactly once.
        """
        serving_args = {
            'model_name': 'model_name',
            'project_id': 'project_id'
        }
        inputs = dict(examples=[self._examples],
                      model=[self._model],
                      model_blessing=[self._model_blessing])
        outputs = dict(inference_result=[self._inference_result])
        # Execution properties: a default DataSpec serialized to JSON with
        # the original proto field names, plus the serving args in
        # custom_config.
        data_spec_json = json_format.MessageToJson(
            bulk_inferrer_pb2.DataSpec(), preserving_proto_field_name=True)
        properties = {
            'data_spec': data_spec_json,
            'custom_config': json_utils.dumps(
                {executor.SERVING_ARGS_KEY: serving_args}),
        }

        # Report a newly created model and make the deployment step fail.
        mock_runner.get_service_name_and_api_version.return_value = ('ml',
                                                                     'v1')
        mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = True
        mock_runner.deploy_model_for_aip_prediction.side_effect = Exception(
            'Deployment failed')

        inferrer = executor.Executor(self._context)
        with self.assertRaises(Exception):
            inferrer.Do(inputs, outputs, properties)

        # Both cleanup helpers are still expected to have been invoked once.
        delete_version = mock_runner.delete_model_version_from_aip_if_exists
        delete_version.assert_called_once_with(mock.ANY, mock.ANY,
                                               serving_args)
        delete_model = mock_runner.delete_model_from_aip_if_exists
        delete_model.assert_called_once_with(mock.ANY, serving_args)