Code example #1
File: runner_test.py  Project: jay90099/tfx
 def setUp(self):
   super().setUp()
   self._output_data_dir = os.path.join(
       os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
       self._testMethodName)
   self._project_id = '12345'
   self._mock_api_client = mock.Mock()
   self._inputs = {}
   self._outputs = {}
   self._training_inputs = {
       'project': self._project_id,
   }
   self._job_id = 'my_jobid'
   # Dict format of exec_properties. custom_config needs to be serialized
   # before being passed into start_cloud_training function.
   self._exec_properties = {
       'custom_config': {
           executor.TRAINING_ARGS_KEY: self._training_inputs,
       },
   }
   self._model_name = 'model_name'
   self._ai_platform_serving_args = {
       'model_name': self._model_name,
       'project_id': self._project_id,
   }
   self._executor_class_path = 'my.executor.Executor'
   with telemetry_utils.scoped_labels(
       {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):
     self._job_labels = telemetry_utils.make_labels_dict()
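All of the examples on this page revolve around the pattern in the last three lines above: open a telemetry_utils.scoped_labels context keyed by LABEL_TFX_EXECUTOR and snapshot the active labels with make_labels_dict. A minimal self-contained sketch of that pattern, assuming telemetry_utils is importable from tfx.utils as in the TFX codebase; the executor class path is a hypothetical placeholder:

from tfx.utils import telemetry_utils

# Hypothetical executor class path, used only for illustration.
executor_class_path = 'my.executor.Executor'

with telemetry_utils.scoped_labels(
    {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
  # Inside the scope the returned labels include the executor label.
  job_labels = telemetry_utils.make_labels_dict()

# Outside the scope, make_labels_dict() no longer includes that label.
base_labels = telemetry_utils.make_labels_dict()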
Code example #2
 def testDoBlessed_Vertex(self, mock_runner):
     endpoint_uri = 'projects/project_id/locations/us-central1/endpoints/12345'
     mock_runner.deploy_model_for_aip_prediction.return_value = endpoint_uri
     self._model_blessing.uri = os.path.join(self._source_data_dir,
                                             'model_validator/blessed')
     self._model_blessing.set_int_custom_property('blessed', 1)
     self._executor.Do(self._input_dict, self._output_dict,
                       self._serialize_custom_config_under_test_vertex())
     executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                      self._executor.__class__.__name__)
     with telemetry_utils.scoped_labels(
         {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
         job_labels = telemetry_utils.make_labels_dict()
     mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
         serving_container_image_uri=self._container_image_uri_vertex,
         model_version_name=mock.ANY,
         ai_platform_serving_args=mock.ANY,
         labels=job_labels,
         serving_path=self._model_push.uri,
         endpoint_region='us-central1',
         enable_vertex=True,
     )
     self.assertPushed()
     self.assertEqual(
         self._model_push.get_string_custom_property('pushed_destination'),
         endpoint_uri)
Code example #3
  def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference,
                                 _):
    input_dict = {
        'examples': [self._examples],
        'model': [self._model],
        'model_blessing': [self._model_blessing],
    }
    output_dict = {
        'inference_result': [self._inference_result],
    }
    ai_platform_serving_args = {
        'model_name': 'model_name',
        'project_id': 'project_id'
    }
    # Create exec properties.
    exec_properties = {
        'data_spec':
            proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
        'custom_config':
            json_utils.dumps(
                {executor.SERVING_ARGS_KEY: ai_platform_serving_args}),
    }
    mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
    mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False

    # Run executor.
    bulk_inferrer = executor.Executor(self._context)
    bulk_inferrer.Do(input_dict, output_dict, exec_properties)

    ai_platform_prediction_model_spec = (
        model_spec_pb2.AIPlatformPredictionModelSpec(
            project_id='project_id',
            model_name='model_name',
            version_name=self._model_version))
    ai_platform_prediction_model_spec.use_serialization_config = True
    inference_endpoint = model_spec_pb2.InferenceSpecType()
    inference_endpoint.ai_platform_prediction_model_spec.CopyFrom(
        ai_platform_prediction_model_spec)
    mock_run_model_inference.assert_called_once_with(mock.ANY, mock.ANY,
                                                     mock.ANY, mock.ANY,
                                                     mock.ANY,
                                                     inference_endpoint)
    executor_class_path = '%s.%s' % (bulk_inferrer.__class__.__module__,
                                     bulk_inferrer.__class__.__name__)
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()
    mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
        serving_path=path_utils.serving_model_path(self._model.uri),
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        labels=job_labels,
        api=mock.ANY,
        skip_model_endpoint_creation=True,
        set_default=False)
    mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
        model_version_name=mock.ANY,
        ai_platform_serving_args=ai_platform_serving_args,
        api=mock.ANY,
        delete_model_endpoint=False)
Code example #4
    def testDoBlessed(self, mock_runner, _):
        self._model_blessing.uri = os.path.join(self._source_data_dir,
                                                'model_validator/blessed')
        self._model_blessing.set_int_custom_property('blessed', 1)
        mock_runner.get_service_name_and_api_version.return_value = ('ml',
                                                                     'v1')
        version = self._model_push.get_string_custom_property('pushed_version')
        mock_runner.deploy_model_for_aip_prediction.return_value = (
            'projects/project_id/models/model_name/versions/{}'.format(version)
        )

        self._executor.Do(self._input_dict, self._output_dict,
                          self._serialize_custom_config_under_test())
        executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                         self._executor.__class__.__name__)
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
            job_labels = telemetry_utils.make_labels_dict()
        mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
            serving_path=self._model_push.uri,
            model_version_name=mock.ANY,
            ai_platform_serving_args=mock.ANY,
            api=mock.ANY,
            labels=job_labels,
        )
        self.assertPushed()
        self.assertEqual(
            self._model_push.get_string_custom_property('pushed_destination'),
            'projects/project_id/models/model_name/versions/{}'.format(
                version))
Code example #5
File: telemetry_utils_test.py  Project: jay90099/tfx
 def testScopedLabels(self):
     """Test for scoped_labels."""
     orig_labels = telemetry_utils.make_labels_dict()
     with telemetry_utils.scoped_labels({'foo': 'bar'}):
         self.assertDictEqual(telemetry_utils.make_labels_dict(),
                              dict({'foo': 'bar'}, **orig_labels))
         with telemetry_utils.scoped_labels({
                 telemetry_utils.LABEL_TFX_EXECUTOR:
                     'custom_component.custom_executor'
         }):
             self.assertDictEqual(
                 telemetry_utils.make_labels_dict(),
                 dict(
                     {
                         'foo': 'bar',
                         # Non-TFX executors are normalized to a generic label.
                         telemetry_utils.LABEL_TFX_EXECUTOR:
                             'third_party_executor'
                     }, **orig_labels))
         with telemetry_utils.scoped_labels({
                 telemetry_utils.LABEL_TFX_EXECUTOR:
                     'tfx.components.example_gen.import_example_gen.executor.Executor'
         }):
             self.assertDictEqual(
                 telemetry_utils.make_labels_dict(),
                 dict(
                     {
                         'foo': 'bar',
                         telemetry_utils.LABEL_TFX_EXECUTOR:  # Label is normalized.
                             'tfx-components-example_gen-import_example_gen-executor-executor'
                     }, **orig_labels))
         with telemetry_utils.scoped_labels({
                 telemetry_utils.LABEL_TFX_EXECUTOR:
                     'tfx.extensions.google_cloud_big_query.example_gen.executor.Executor'
         }):
             self.assertDictEqual(
                 telemetry_utils.make_labels_dict(),
                 dict(
                     {
                         'foo': 'bar',
                         telemetry_utils.LABEL_TFX_EXECUTOR:  # Label is normalized.
                             'tfx-extensions-google_cloud_big_query-example_gen-executor-exec'
                     }, **orig_labels))
Code example #6
  def create_training_job(self, input_dict: Dict[str, List[types.Artifact]],
                          output_dict: Dict[str, List[types.Artifact]],
                          exec_properties: Dict[str, Any],
                          executor_class_path: str, job_args: Dict[str, Any],
                          job_id: Optional[str]) -> Dict[str, Any]:
    """Get training args for runner._launch_aip_training.

    The training args contain the inputs/outputs/exec_properties to the
    tfx.scripts.run_executor module.

    Args:
      input_dict: Passthrough input dict for tfx.components.Trainer.executor.
      output_dict: Passthrough input dict for tfx.components.Trainer.executor.
      exec_properties: Passthrough input dict for
        tfx.components.Trainer.executor.
      executor_class_path: class path for TFX core default trainer.
      job_args: Training input argument for AI Platform training job.
        'pythonModule', 'pythonVersion' and 'runtimeVersion' will be inferred.
        For the full set of parameters, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput
      job_id: Job ID for the AI Platform Training job. If not supplied, a
        system-determined unique ID is given. Refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#resource-job

    Returns:
      A dict containing the training arguments
    """
    training_inputs = job_args.copy()

    container_command = self.generate_container_command(input_dict, output_dict,
                                                        exec_properties,
                                                        executor_class_path)

    if not training_inputs.get('masterConfig'):
      training_inputs['masterConfig'] = {
          'imageUri': _TFX_IMAGE,
      }

    # Always use our own entrypoint instead of relying on container default.
    if 'containerCommand' in training_inputs['masterConfig']:
      logging.warn('Overriding custom value of containerCommand')
    training_inputs['masterConfig']['containerCommand'] = container_command

    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()

    # 'tfx_YYYYmmddHHMMSS' is the default job ID if not explicitly specified.
    job_id = job_id or 'tfx_{}'.format(
        datetime.datetime.now().strftime('%Y%m%d%H%M%S'))

    caip_job = {
        'job_id': job_id,
        'training_input': training_inputs,
        'labels': job_labels
    }

    return caip_job
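As a rough, hedged illustration of the defaulting logic above (not code from the project), the sketch below builds a TrainingInput-style dict the same way create_training_job does; the project ID, region, and image URI are placeholders standing in for real values and the module-level _TFX_IMAGE constant:

import datetime

# Hypothetical TrainingInput-style arguments; see the TrainingInput schema
# linked in the docstring above.
training_inputs = {
    'project': 'my-gcp-project',
    'region': 'us-central1',
    'scaleTier': 'BASIC',
}

# Mirrors the masterConfig defaulting above, with a placeholder image URI.
if not training_inputs.get('masterConfig'):
  training_inputs['masterConfig'] = {'imageUri': 'gcr.io/my-project/tfx:latest'}

# Default job ID format used when no explicit job_id is supplied.
job_id = 'tfx_{}'.format(datetime.datetime.now().strftime('%Y%m%d%H%M%S'))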
Code example #7
def ReadFromBigQuery(pipeline: beam.Pipeline,
                     query: str) -> beam.pvalue.PCollection:
    """Read data from BigQuery.

  Args:
    pipeline: Beam pipeline.
    query: A BigQuery sql string.

  Returns:
    PCollection of dict.
  """
    return (pipeline
            | 'ReadFromBigQuery' >> bigquery.ReadFromBigQuery(
                query=query,
                use_standard_sql=True,
                bigquery_job_labels=telemetry_utils.make_labels_dict()))
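A hedged usage sketch for the helper above: calling it inside a Beam pipeline attaches the telemetry labels to the BigQuery read job. The query is a placeholder, and an actual run additionally needs GCP credentials and a project set in the pipeline options:

import apache_beam as beam

# Placeholder standard-SQL query, used only for illustration.
_QUERY = 'SELECT 1 AS one'

with beam.Pipeline() as pipeline:
  rows = ReadFromBigQuery(pipeline, _QUERY)
  # Each element is a dict keyed by column name, e.g. {'one': 1}.
  _ = rows | 'PrintRows' >> beam.Map(print)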
Code example #8
 def testDoBlessedOnRegionalEndpoint_Vertex(self, mock_runner):
     endpoint_uri = 'projects/project_id/locations/us-west1/endpoints/12345'
     mock_runner.deploy_model_for_aip_prediction.return_value = endpoint_uri
     self._exec_properties_vertex = {
         'custom_config': {
             constants.SERVING_ARGS_KEY: {
                 'model_name': 'model_name',
                 'project_id': 'project_id'
             },
             constants.VERTEX_CONTAINER_IMAGE_URI_KEY:
             self._container_image_uri_vertex,
             constants.ENABLE_VERTEX_KEY: True,
             constants.VERTEX_REGION_KEY: 'us-west1',
         },
     }
     self._model_blessing.uri = os.path.join(self._source_data_dir,
                                             'model_validator/blessed')
     self._model_blessing.set_int_custom_property('blessed', 1)
     self._executor.Do(self._input_dict, self._output_dict,
                       self._serialize_custom_config_under_test_vertex())
     executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                      self._executor.__class__.__name__)
     with telemetry_utils.scoped_labels(
         {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
         job_labels = telemetry_utils.make_labels_dict()
     mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
         serving_path=self._model_push.uri,
         model_version_name=mock.ANY,
         ai_platform_serving_args=mock.ANY,
         labels=job_labels,
         serving_container_image_uri=self._container_image_uri_vertex,
         endpoint_region='us-west1',
         enable_vertex=True,
     )
     self.assertPushed()
     self.assertEqual(
         self._model_push.get_string_custom_property('pushed_destination'),
         endpoint_uri)
Code example #9
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]):
        """Overrides the tfx_pusher_executor.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from evaluator.
      output_dict: Output dict from key to a list of artifacts, including:
        - model_push: A list of 'ModelPushPath' artifact of size one. It will
          include the model in this push execution if the model was pushed.
      exec_properties: Mostly a passthrough input dict for
        tfx.components.Pusher.executor.  custom_config.bigquery_serving_args is
        consumed by this class, including:
        - bq_dataset_id: ID of the dataset you're creating or replacing
        - model_name: name of the model you're creating or replacing
        - project_id: GCP project where the model will be stored. It is also
          the project where the query is executed unless a compute_project_id
          is provided.
        - compute_project_id: GCP project where the query is executed. If not
          provided, the query is executed in project_id.
        For the full set of parameters supported by
        Big Query ML, refer to https://cloud.google.com/bigquery-ml/

    Returns:
      None
    Raises:
      ValueError:
        If bigquery_serving_args is not in exec_properties.custom_config.
        If pipeline_root is not 'gs://...'
      RuntimeError: if the Big Query job failed.

    Example usage:
      from tfx.extensions.google_cloud_big_query.pusher import executor

      pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        custom_executor_spec=executor_spec.ExecutorClassSpec(executor.Executor),
        custom_config={
          'bigquery_serving_args': {
            'model_name': 'your_model_name',
            'project_id': 'your_gcp_storage_project',
            'bq_dataset_id': 'your_dataset_id',
            'compute_project_id': 'your_gcp_compute_project',
          },
        },
      )
    """
        self._log_startup(input_dict, output_dict, exec_properties)
        model_push = artifact_utils.get_single_instance(
            output_dict[standard_component_specs.PUSHED_MODEL_KEY])
        if not self.CheckBlessing(input_dict):
            self._MarkNotPushed(model_push)
            return

        custom_config = json_utils.loads(
            exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
        if custom_config is not None and not isinstance(custom_config, Dict):
            raise ValueError(
                'custom_config in execution properties needs to be a '
                'dict.')

        bigquery_serving_args = custom_config.get(SERVING_ARGS_KEY)
        # Error out if the BigQuery ML configuration is missing.
        if bigquery_serving_args is None:
            raise ValueError('Big Query ML configuration was not provided')

        bq_model_uri = '.'.join([
            bigquery_serving_args[_PROJECT_ID_KEY],
            bigquery_serving_args[_BQ_DATASET_ID_KEY],
            bigquery_serving_args[_MODEL_NAME_KEY],
        ])

        # Deploy the model.
        io_utils.copy_dir(src=self.GetModelPath(input_dict),
                          dst=model_push.uri)
        model_path = model_push.uri
        if not model_path.startswith(_GCS_PREFIX):
            raise ValueError(
                'pipeline_root must be gs:// for BigQuery ML Pusher.')

        logging.info(
            'Deploying the model to BigQuery ML for serving: %s from %s',
            bigquery_serving_args, model_path)

        query = _BQML_CREATE_OR_REPLACE_MODEL_QUERY_TEMPLATE.format(
            model_uri=bq_model_uri, model_path=model_path)

        # TODO(zhitaoli): Refactor the executor_class_path creation into a common
        # utility function.
        executor_class_path = '%s.%s' % (self.__class__.__module__,
                                         self.__class__.__name__)
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
            default_query_job_config = bigquery.job.QueryJobConfig(
                labels=telemetry_utils.make_labels_dict())
        # TODO(b/181368842) Add integration test for BQML Pusher + Managed Pipeline
        project_id = (bigquery_serving_args.get(_COMPUTE_PROJECT_ID_KEY)
                      or bigquery_serving_args[_PROJECT_ID_KEY])
        client = bigquery.Client(
            default_query_job_config=default_query_job_config,
            project=project_id)

        try:
            query_job = client.query(query)
            query_job.result()  # Waits for the query to finish
        except Exception as e:
            raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) from e

        logging.info('Successfully deployed model %s serving from %s',
                     bq_model_uri, model_path)

        # Set the push_destination to the BigQuery model URI.
        self._MarkPushed(model_push, pushed_destination=bq_model_uri)
Code example #10
File: executor.py  Project: jay90099/tfx
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Runs batch inference on a given model with given input examples.

    This function creates a new model (if necessary) and a new model version
    before inference, and cleans up resources after inference. It provides
    re-executability because it cleans up (only) the model resources that are
    created during the process, even if the inference job fails.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - examples: examples for inference.
        - model: exported model.
        - model_blessing: model blessing result
      output_dict: Output dict from output key to a list of Artifacts.
        - inference_result: bulk inference results.
      exec_properties: A dict of execution properties.
        - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
        - custom_config: custom_config.ai_platform_serving_args needs to contain
          the serving job parameters sent to Google Cloud AI Platform. For the
          full set of parameters, refer to
          https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

    Returns:
      None
    """
        self._log_startup(input_dict, output_dict, exec_properties)

        if output_dict.get('inference_result'):
            inference_result = artifact_utils.get_single_instance(
                output_dict['inference_result'])
        else:
            inference_result = None
        if output_dict.get('output_examples'):
            output_examples = artifact_utils.get_single_instance(
                output_dict['output_examples'])
        else:
            output_examples = None

        if 'examples' not in input_dict:
            raise ValueError('`examples` is missing in input dict.')
        if 'model' not in input_dict:
            raise ValueError('Input models are not valid, a model '
                             'needs to be specified.')
        if 'model_blessing' in input_dict:
            model_blessing = artifact_utils.get_single_instance(
                input_dict['model_blessing'])
            if not model_utils.is_model_blessed(model_blessing):
                logging.info('Model on %s was not blessed', model_blessing.uri)
                return
        else:
            logging.info(
                'Model blessing is not provided; the exported model will be '
                'used.')
        if _CUSTOM_CONFIG_KEY not in exec_properties:
            raise ValueError(
                'Input exec properties are not valid, {} '
                'needs to be specified.'.format(_CUSTOM_CONFIG_KEY))

        custom_config = json_utils.loads(
            exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
        if custom_config is not None and not isinstance(custom_config, Dict):
            raise ValueError(
                'custom_config in execution properties needs to be a '
                'dict.')
        ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
        if not ai_platform_serving_args:
            raise ValueError(
                '`ai_platform_serving_args` is missing in `custom_config`')
        service_name, api_version = runner.get_service_name_and_api_version(
            ai_platform_serving_args)
        executor_class_path = '%s.%s' % (self.__class__.__module__,
                                         self.__class__.__name__)
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
            job_labels = telemetry_utils.make_labels_dict()
        model = artifact_utils.get_single_instance(input_dict['model'])
        model_path = path_utils.serving_model_path(
            model.uri, path_utils.is_old_model_artifact(model))
        logging.info('Use exported model from %s.', model_path)
        # Use model artifact uri to generate model version to guarantee the
        # 1:1 mapping from model version to model.
        model_version = 'version_' + hashlib.sha256(
            model.uri.encode()).hexdigest()
        inference_spec = self._get_inference_spec(model_path, model_version,
                                                  ai_platform_serving_args)
        data_spec = bulk_inferrer_pb2.DataSpec()
        proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
        output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
        if exec_properties.get('output_example_spec'):
            proto_utils.json_to_proto(exec_properties['output_example_spec'],
                                      output_example_spec)
        endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
        if endpoint and 'regions' in ai_platform_serving_args:
            raise ValueError(
                '`endpoint` and `ai_platform_serving_args.regions` cannot be set simultaneously'
            )
        api = discovery.build(
            service_name,
            api_version,
            requestBuilder=telemetry_utils.TFXHttpRequest,
            client_options=client_options.ClientOptions(api_endpoint=endpoint),
        )
        new_model_endpoint_created = False
        try:
            new_model_endpoint_created = runner.create_model_for_aip_prediction_if_not_exist(
                job_labels, ai_platform_serving_args, api)
            runner.deploy_model_for_aip_prediction(
                serving_path=model_path,
                model_version_name=model_version,
                ai_platform_serving_args=ai_platform_serving_args,
                api=api,
                labels=job_labels,
                skip_model_endpoint_creation=True,
                set_default=False,
            )
            self._run_model_inference(data_spec, output_example_spec,
                                      input_dict['examples'], output_examples,
                                      inference_result, inference_spec)
        except Exception as e:
            logging.error(
                'Error in executing CloudAIBulkInferrerComponent: %s', str(e))
            raise
        finally:
            # Guarantee newly created resources are cleaned up even if the inference
            # job failed.

            # Clean up the newly deployed model.
            runner.delete_model_from_aip_if_exists(
                model_version_name=model_version,
                ai_platform_serving_args=ai_platform_serving_args,
                api=api,
                delete_model_endpoint=new_model_endpoint_created)
Code example #11
File: executor.py  Project: jay90099/tfx
  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]):
    """Overrides the tfx_pusher_executor.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from evaluator.
      output_dict: Output dict from key to a list of artifacts, including:
        - model_push: A list of 'ModelPushPath' artifact of size one. It will
          include the model in this push execution if the model was pushed.
      exec_properties: Mostly a passthrough input dict for
        tfx.components.Pusher.executor.  The following keys in `custom_config`
        are consumed by this class:
        - ai_platform_serving_args: Arguments for the serving job.
          - For the full set of parameters supported by Google Cloud AI
            Platform, refer to
            https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions#Version.
          - For the full set of parameters supported by Google Cloud Vertex AI,
            refer to
            https://googleapis.dev/python/aiplatform/latest/aiplatform.html?highlight=deploy#google.cloud.aiplatform.Model.deploy
        - endpoint: Optional endpoint override.
          - For Google Cloud AI Platform, this should be in the format
            `https://[region]-ml.googleapis.com`. Defaults to the global
            endpoint if not set. Using a regional endpoint is recommended by
            Cloud AI Platform. When set, the 'regions' key in
            ai_platform_serving_args cannot be set. For more details, please see
            https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints#using_regional_endpoints
          - For Google Cloud Vertex AI, this should just be the region (e.g.
            'us-central1'). For available regions, please see
            https://cloud.google.com/vertex-ai/docs/general/locations

    Raises:
      ValueError:
        If ai_platform_serving_args is not in exec_properties.custom_config.
        If Serving model path does not start with gs://.
        If 'endpoint' and 'regions' are set simultaneously.
      RuntimeError: if the Google Cloud AI Platform training job failed.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    custom_config = json_utils.loads(
        exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
    if custom_config is not None and not isinstance(custom_config, Dict):
      raise ValueError('custom_config in execution properties needs to be a '
                       'dict.')
    ai_platform_serving_args = custom_config.get(constants.SERVING_ARGS_KEY)
    if not ai_platform_serving_args:
      raise ValueError(
          '\'ai_platform_serving_args\' is missing in \'custom_config\'')
    model_push = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.PUSHED_MODEL_KEY])
    if not self.CheckBlessing(input_dict):
      self._MarkNotPushed(model_push)
      return

    # Deploy the model.
    io_utils.copy_dir(src=self.GetModelPath(input_dict), dst=model_push.uri)
    model_path = model_push.uri

    executor_class_path = '%s.%s' % (self.__class__.__module__,
                                     self.__class__.__name__)
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()

    enable_vertex = custom_config.get(constants.ENABLE_VERTEX_KEY)
    if enable_vertex:
      if custom_config.get(constants.ENDPOINT_ARGS_KEY):
        deprecation_utils.warn_deprecated(
            '\'endpoint\' is deprecated. Please use '
            '\'ai_platform_vertex_region\' instead.'
        )
      if 'regions' in ai_platform_serving_args:
        deprecation_utils.warn_deprecated(
            '\'ai_platform_serving_args.regions\' is deprecated. Please use '
            '\'ai_platform_vertex_region\' instead.'
        )
      endpoint_region = custom_config.get(constants.VERTEX_REGION_KEY)
      # TODO(jjong): Introduce Versioning.
      # Note that we're adding a "v" prefix because Cloud AI Prediction only
      # allows version names that start with a letter and contain only letters,
      # digits, and underscores.
      model_name = 'v{}'.format(int(time.time()))
      container_image_uri = custom_config.get(
          constants.VERTEX_CONTAINER_IMAGE_URI_KEY)

      pushed_model_path = runner.deploy_model_for_aip_prediction(
          serving_container_image_uri=container_image_uri,
          model_version_name=model_name,
          ai_platform_serving_args=ai_platform_serving_args,
          endpoint_region=endpoint_region,
          labels=job_labels,
          serving_path=model_path,
          enable_vertex=True,
      )

      self._MarkPushed(
          model_push,
          pushed_destination=pushed_model_path)

    else:
      endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
      if endpoint and 'regions' in ai_platform_serving_args:
        raise ValueError(
            '\'endpoint\' and \'ai_platform_serving_args.regions\' cannot be set simultaneously'
        )
      # TODO(jjong): Introduce Versioning.
      # Note that we're adding a "v" prefix because Cloud AI Prediction only
      # allows version names that start with a letter and contain only letters,
      # digits, and underscores.
      model_version = 'v{}'.format(int(time.time()))
      endpoint = endpoint or runner.DEFAULT_ENDPOINT
      service_name, api_version = runner.get_service_name_and_api_version(
          ai_platform_serving_args)
      api = discovery.build(
          service_name,
          api_version,
          requestBuilder=telemetry_utils.TFXHttpRequest,
          client_options=client_options.ClientOptions(api_endpoint=endpoint),
      )
      pushed_model_version_path = runner.deploy_model_for_aip_prediction(
          serving_path=model_path,
          model_version_name=model_version,
          ai_platform_serving_args=ai_platform_serving_args,
          api=api,
          labels=job_labels,
      )

      self._MarkPushed(
          model_push,
          pushed_destination=pushed_model_version_path,
          pushed_version=model_version)
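For reference, a hedged sketch of the two custom_config shapes this Do method branches on, written with the same constants the executor reads; the import path is an assumption, and model name, project, region, endpoint, and image URI are placeholders:

from tfx.extensions.google_cloud_ai_platform import constants  # assumed path

# Vertex AI branch (constants.ENABLE_VERTEX_KEY set to True).
vertex_custom_config = {
    constants.SERVING_ARGS_KEY: {
        'model_name': 'my_model',
        'project_id': 'my-project',
    },
    constants.VERTEX_REGION_KEY: 'us-central1',
    constants.VERTEX_CONTAINER_IMAGE_URI_KEY: 'gcr.io/my-project/serving:latest',
    constants.ENABLE_VERTEX_KEY: True,
}

# Legacy Cloud AI Platform branch; the optional regional endpoint cannot be
# combined with a 'regions' entry in ai_platform_serving_args.
caip_custom_config = {
    constants.SERVING_ARGS_KEY: {
        'model_name': 'my_model',
        'project_id': 'my-project',
    },
    constants.ENDPOINT_ARGS_KEY: 'https://us-central1-ml.googleapis.com',
}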
Code example #12
  def create_training_job(self, input_dict: Dict[str, List[types.Artifact]],
                          output_dict: Dict[str, List[types.Artifact]],
                          exec_properties: Dict[str, Any],
                          executor_class_path: str, job_args: Dict[str, Any],
                          job_id: Optional[str]) -> Dict[str, Any]:
    """Get training args for runner._launch_aip_training.

    The training args contain the inputs/outputs/exec_properties to the
    tfx.scripts.run_executor module.

    Args:
      input_dict: Passthrough input dict for tfx.components.Trainer.executor.
      output_dict: Passthrough input dict for tfx.components.Trainer.executor.
      exec_properties: Passthrough input dict for
        tfx.components.Trainer.executor.
      executor_class_path: class path for TFX core default trainer.
      job_args: CustomJob for Vertex AI custom training. See
        https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob
          for the detailed schema.
        [Deprecated]: job_args also supports specifying only the CustomJobSpec
          instead of CustomJob. However, this functionality is deprecated.
      job_id: Display name for the AI Platform (Unified) custom training job.
        If not supplied, a system-determined unique ID is given. Refer to
        https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob

    Returns:
      A dict containing the Vertex AI CustomJob
    """
    job_args = job_args.copy()
    if job_args.get('job_spec'):
      custom_job_spec = job_args['job_spec']
    else:
      logging.warn(
          'job_spec key was not found. Parsing job_args as CustomJobSpec instead'
      )
      custom_job_spec = job_args

    container_command = self.generate_container_command(input_dict, output_dict,
                                                        exec_properties,
                                                        executor_class_path)

    if not custom_job_spec.get('worker_pool_specs'):
      custom_job_spec['worker_pool_specs'] = [{}]

    for worker_pool_spec in custom_job_spec['worker_pool_specs']:
      if not worker_pool_spec.get('container_spec'):
        worker_pool_spec['container_spec'] = {
            'image_uri': _TFX_IMAGE,
        }

      # Always use our own entrypoint instead of relying on container default.
      if 'command' in worker_pool_spec['container_spec']:
        logging.warn('Overriding custom value of container_spec.command')
      worker_pool_spec['container_spec']['command'] = container_command

    # 'tfx_YYYYmmddHHMMSS_xxxxxxxx' is the default job display name if not
    # explicitly specified.
    job_id = job_args.get('display_name', job_id)
    job_id = job_id or 'tfx_{}_{}'.format(
        datetime.datetime.now().strftime('%Y%m%d%H%M%S'),
        '%08x' % random.getrandbits(32))

    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()
    job_labels.update(job_args.get('labels', {}))

    encryption_spec = job_args.get('encryption_spec', {})

    custom_job = {
        'display_name': job_id,
        'job_spec': custom_job_spec,
        'labels': job_labels,
        'encryption_spec': encryption_spec,
    }

    return custom_job
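A hedged sketch of a minimal CustomJob-shaped job_args value accepted by the method above; the display name, machine type, and image URI are placeholders:

# Hypothetical CustomJob payload; see the CustomJob reference linked in the
# docstring above.
job_args = {
    'display_name': 'my-vertex-training-job',
    'job_spec': {
        'worker_pool_specs': [{
            'machine_spec': {'machine_type': 'n1-standard-4'},
            'replica_count': 1,
            'container_spec': {
                # If container_spec were omitted, the method above would fall
                # back to the default TFX image; 'command' is always overridden
                # with the TFX entrypoint.
                'image_uri': 'gcr.io/my-project/trainer:latest',
            },
        }],
    },
    'labels': {'team': 'ml-platform'},  # merged with the telemetry labels
}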