Ejemplo n.º 1
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]):
        """Overrides the tfx_pusher_executor.

        Copies the blessed exported model into the push artifact's URI and
        deploys it to Google Cloud AI Platform Prediction as a new model version
        named `v<unix-timestamp>`. If the model is not blessed, the push artifact
        is marked as not pushed and nothing is deployed.

        Args:
          input_dict: Input dict from input key to a list of artifacts, including:
            - model_export: exported model from trainer.
            - model_blessing: model blessing path from evaluator.
          output_dict: Output dict from key to a list of artifacts, including:
            - model_push: A list of 'ModelPushPath' artifact of size one. It will
              include the model in this push execution if the model was pushed.
          exec_properties: Mostly a passthrough input dict for
            tfx.components.Pusher.executor.  custom_config.ai_platform_serving_args
            is consumed by this class.  For the full set of parameters supported
            by Google Cloud AI Platform, refer to
            https://cloud.google.com/ml-engine/docs/tensorflow/deploying-models#creating_a_model_version.

        Raises:
          ValueError:
            If custom_config is present but is not a dict.
            If ai_platform_serving_args is not in exec_properties.custom_config
            (including when custom_config is absent altogether).
            If Serving model path does not start with gs://.
          RuntimeError: if the Google Cloud AI Platform deployment failed.
        """
        self._log_startup(input_dict, output_dict, exec_properties)
        model_push = artifact_utils.get_single_instance(
            output_dict[tfx_pusher_executor.PUSHED_MODEL_KEY])
        if not self.CheckBlessing(input_dict):
            self._MarkNotPushed(model_push)
            return

        model_export = artifact_utils.get_single_instance(
            input_dict[tfx_pusher_executor.MODEL_KEY])

        custom_config = json_utils.loads(
            exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
        # A missing custom_config deserializes to None. Treat it as empty so the
        # serving-args check below raises the documented ValueError rather than
        # an AttributeError from calling .get() on None.
        if custom_config is None:
            custom_config = {}
        if not isinstance(custom_config, dict):
            raise ValueError(
                'custom_config in execution properties needs to be a '
                'dict.')

        ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
        if not ai_platform_serving_args:
            raise ValueError(
                '\'ai_platform_serving_args\' is missing in \'custom_config\'')
        service_name, api_version = runner.get_service_name_and_api_version(
            ai_platform_serving_args)
        # Deploy the model.
        io_utils.copy_dir(src=path_utils.serving_model_path(model_export.uri),
                          dst=model_push.uri)
        model_path = model_push.uri
        # TODO(jjong): Introduce Versioning.
        # Note that we're adding "v" prefix as Cloud AI Prediction only allows the
        # version name that starts with letters, and contains letters, digits,
        # underscore only.
        model_version = 'v{}'.format(int(time.time()))
        executor_class_path = '%s.%s' % (self.__class__.__module__,
                                         self.__class__.__name__)
        # Tag the deployment with the executor's fully-qualified class path.
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
            job_labels = telemetry_utils.get_labels_dict()
        runner.deploy_model_for_aip_prediction(
            discovery.build(service_name, api_version),
            model_path,
            model_version,
            ai_platform_serving_args,
            job_labels,
        )

        self._MarkPushed(
            model_push,
            pushed_destination=_CAIP_MODEL_VERSION_PATH_FORMAT.format(
                project_id=ai_platform_serving_args['project_id'],
                model=ai_platform_serving_args['model_name'],
                version=model_version),
            pushed_version=model_version)
Ejemplo n.º 2
0
  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]):
    """Overrides the tfx_pusher_executor.

    Copies the blessed model into the push artifact's URI and deploys it either
    to Google Cloud Vertex AI (when the Vertex-enable key in `custom_config` is
    truthy) or to Google Cloud AI Platform Prediction. If the model is not
    blessed, the push artifact is marked as not pushed and nothing is deployed.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from evaluator.
      output_dict: Output dict from key to a list of artifacts, including:
        - model_push: A list of 'ModelPushPath' artifact of size one. It will
          include the model in this push execution if the model was pushed.
      exec_properties: Mostly a passthrough input dict for
        tfx.components.Pusher.executor.  The following keys in `custom_config`
        are consumed by this class:
        - ai_platform_serving_args: For the full set of supported parameters:
          - For Google Cloud AI Platform, refer to
            https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions#Version.
          - For Google Cloud Vertex AI, refer to
            https://googleapis.dev/python/aiplatform/latest/aiplatform.html?highlight=deploy#google.cloud.aiplatform.Model.deploy
        - endpoint: Optional endpoint override.
          - For Google Cloud AI Platform, this should be in the format
            `https://[region]-ml.googleapis.com`. Defaults to the global
            endpoint if not set. Using a regional endpoint is recommended by
            Cloud AI Platform. When set, the 'regions' key in
            ai_platform_serving_args cannot be set. For more details, please see
            https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints#using_regional_endpoints
          - For Google Cloud Vertex AI, this is deprecated; use
            'ai_platform_vertex_region' instead, which should be just a region
            (e.g. 'us-central1'). For available regions, please see
            https://cloud.google.com/vertex-ai/docs/general/locations
        - The key named by `constants.ENABLE_VERTEX_KEY`: when truthy, deploy
          to Vertex AI instead of AI Platform.

    Raises:
      ValueError:
        If custom_config is present but is not a dict.
        If ai_platform_serving_args is not in exec_properties.custom_config
        (including when custom_config is absent altogether).
        If Serving model path does not start with gs://.
        If 'endpoint' and 'regions' are set simultaneously (AI Platform only).
      RuntimeError: if the Google Cloud deployment failed.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    custom_config = json_utils.loads(
        exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
    # A missing custom_config deserializes to None. Treat it as empty so the
    # serving-args check below raises the documented ValueError rather than an
    # AttributeError from calling .get() on None.
    if custom_config is None:
      custom_config = {}
    if not isinstance(custom_config, dict):
      raise ValueError('custom_config in execution properties needs to be a '
                       'dict.')
    ai_platform_serving_args = custom_config.get(constants.SERVING_ARGS_KEY)
    if not ai_platform_serving_args:
      raise ValueError(
          '\'ai_platform_serving_args\' is missing in \'custom_config\'')
    model_push = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.PUSHED_MODEL_KEY])
    if not self.CheckBlessing(input_dict):
      self._MarkNotPushed(model_push)
      return

    # Deploy the model.
    io_utils.copy_dir(src=self.GetModelPath(input_dict), dst=model_push.uri)
    model_path = model_push.uri

    executor_class_path = '%s.%s' % (self.__class__.__module__,
                                     self.__class__.__name__)
    # Tag the deployment with the executor's fully-qualified class path.
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()

    # TODO(jjong): Introduce Versioning.
    # Note that we're adding "v" prefix as Cloud AI Prediction only allows the
    # version name that starts with letters, and contains letters, digits,
    # underscore only. The same name is used for both deployment targets.
    model_version = 'v{}'.format(int(time.time()))

    enable_vertex = custom_config.get(constants.ENABLE_VERTEX_KEY)
    if enable_vertex:
      if custom_config.get(constants.ENDPOINT_ARGS_KEY):
        deprecation_utils.warn_deprecated(
            '\'endpoint\' is deprecated. Please use '
            '\'ai_platform_vertex_region\' instead.'
        )
      if 'regions' in ai_platform_serving_args:
        deprecation_utils.warn_deprecated(
            '\'ai_platform_serving_args.regions\' is deprecated. Please use '
            '\'ai_platform_vertex_region\' instead.'
        )
      endpoint_region = custom_config.get(constants.VERTEX_REGION_KEY)
      container_image_uri = custom_config.get(
          constants.VERTEX_CONTAINER_IMAGE_URI_KEY)

      pushed_model_path = runner.deploy_model_for_aip_prediction(
          serving_container_image_uri=container_image_uri,
          model_version_name=model_version,
          ai_platform_serving_args=ai_platform_serving_args,
          endpoint_region=endpoint_region,
          labels=job_labels,
          serving_path=model_path,
          enable_vertex=True,
      )

      self._MarkPushed(
          model_push,
          pushed_destination=pushed_model_path)

    else:
      endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
      if endpoint and 'regions' in ai_platform_serving_args:
        raise ValueError(
            '\'endpoint\' and \'ai_platform_serving_args.regions\' cannot be set simultaneously'
        )
      endpoint = endpoint or runner.DEFAULT_ENDPOINT
      service_name, api_version = runner.get_service_name_and_api_version(
          ai_platform_serving_args)
      api = discovery.build(
          service_name,
          api_version,
          requestBuilder=telemetry_utils.TFXHttpRequest,
          client_options=client_options.ClientOptions(api_endpoint=endpoint),
      )
      pushed_model_version_path = runner.deploy_model_for_aip_prediction(
          serving_path=model_path,
          model_version_name=model_version,
          ai_platform_serving_args=ai_platform_serving_args,
          api=api,
          labels=job_labels,
      )

      self._MarkPushed(
          model_push,
          pushed_destination=pushed_model_version_path,
          pushed_version=model_version)