def setUp(self):
  super().setUp()
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  self._project_id = '12345'
  self._mock_api_client = mock.Mock()
  self._inputs = {}
  self._outputs = {}
  self._training_inputs = {
      'project': self._project_id,
  }
  self._job_id = 'my_jobid'
  # Dict format of exec_properties. custom_config needs to be serialized
  # before being passed into start_cloud_training function.
  self._exec_properties = {
      'custom_config': {
          executor.TRAINING_ARGS_KEY: self._training_inputs,
      },
  }
  self._model_name = 'model_name'
  self._ai_platform_serving_args = {
      'model_name': self._model_name,
      'project_id': self._project_id,
  }
  self._executor_class_path = 'my.executor.Executor'
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: self._executor_class_path}):
    self._job_labels = telemetry_utils.make_labels_dict()
def testDoBlessed_Vertex(self, mock_runner):
  endpoint_uri = 'projects/project_id/locations/us-central1/endpoints/12345'
  mock_runner.deploy_model_for_aip_prediction.return_value = endpoint_uri
  self._model_blessing.uri = os.path.join(self._source_data_dir,
                                          'model_validator/blessed')
  self._model_blessing.set_int_custom_property('blessed', 1)
  self._executor.Do(self._input_dict, self._output_dict,
                    self._serialize_custom_config_under_test_vertex())
  executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                   self._executor.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
      serving_container_image_uri=self._container_image_uri_vertex,
      model_version_name=mock.ANY,
      ai_platform_serving_args=mock.ANY,
      labels=job_labels,
      serving_path=self._model_push.uri,
      endpoint_region='us-central1',
      enable_vertex=True,
  )
  self.assertPushed()
  self.assertEqual(
      self._model_push.get_string_custom_property('pushed_destination'),
      endpoint_uri)
def testDoSkippedModelCreation(self, mock_runner, mock_run_model_inference,
                               _):
  input_dict = {
      'examples': [self._examples],
      'model': [self._model],
      'model_blessing': [self._model_blessing],
  }
  output_dict = {
      'inference_result': [self._inference_result],
  }
  ai_platform_serving_args = {
      'model_name': 'model_name',
      'project_id': 'project_id'
  }
  # Create exec properties.
  exec_properties = {
      'data_spec':
          proto_utils.proto_to_json(bulk_inferrer_pb2.DataSpec()),
      'custom_config':
          json_utils.dumps(
              {executor.SERVING_ARGS_KEY: ai_platform_serving_args}),
  }
  mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
  mock_runner.create_model_for_aip_prediction_if_not_exist.return_value = False

  # Run executor.
  bulk_inferrer = executor.Executor(self._context)
  bulk_inferrer.Do(input_dict, output_dict, exec_properties)

  ai_platform_prediction_model_spec = (
      model_spec_pb2.AIPlatformPredictionModelSpec(
          project_id='project_id',
          model_name='model_name',
          version_name=self._model_version))
  ai_platform_prediction_model_spec.use_serialization_config = True
  inference_endpoint = model_spec_pb2.InferenceSpecType()
  inference_endpoint.ai_platform_prediction_model_spec.CopyFrom(
      ai_platform_prediction_model_spec)
  mock_run_model_inference.assert_called_once_with(mock.ANY, mock.ANY,
                                                   mock.ANY, mock.ANY,
                                                   mock.ANY,
                                                   inference_endpoint)
  executor_class_path = '%s.%s' % (bulk_inferrer.__class__.__module__,
                                   bulk_inferrer.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
      serving_path=path_utils.serving_model_path(self._model.uri),
      model_version_name=mock.ANY,
      ai_platform_serving_args=ai_platform_serving_args,
      labels=job_labels,
      api=mock.ANY,
      skip_model_endpoint_creation=True,
      set_default=False)
  mock_runner.delete_model_from_aip_if_exists.assert_called_once_with(
      model_version_name=mock.ANY,
      ai_platform_serving_args=ai_platform_serving_args,
      api=mock.ANY,
      delete_model_endpoint=False)
def testDoBlessed(self, mock_runner, _):
  self._model_blessing.uri = os.path.join(self._source_data_dir,
                                          'model_validator/blessed')
  self._model_blessing.set_int_custom_property('blessed', 1)
  mock_runner.get_service_name_and_api_version.return_value = ('ml', 'v1')
  version = self._model_push.get_string_custom_property('pushed_version')
  mock_runner.deploy_model_for_aip_prediction.return_value = (
      'projects/project_id/models/model_name/versions/{}'.format(version))
  self._executor.Do(self._input_dict, self._output_dict,
                    self._serialize_custom_config_under_test())
  executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                   self._executor.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
      serving_path=self._model_push.uri,
      model_version_name=mock.ANY,
      ai_platform_serving_args=mock.ANY,
      api=mock.ANY,
      labels=job_labels,
  )
  self.assertPushed()
  self.assertEqual(
      self._model_push.get_string_custom_property('pushed_destination'),
      'projects/project_id/models/model_name/versions/{}'.format(version))
def testScopedLabels(self):
  """Test for scoped_labels."""
  orig_labels = telemetry_utils.make_labels_dict()
  with telemetry_utils.scoped_labels({'foo': 'bar'}):
    self.assertDictEqual(telemetry_utils.make_labels_dict(),
                         dict({'foo': 'bar'}, **orig_labels))
    with telemetry_utils.scoped_labels({
        telemetry_utils.LABEL_TFX_EXECUTOR: 'custom_component.custom_executor'
    }):
      self.assertDictEqual(
          telemetry_utils.make_labels_dict(),
          dict(
              {
                  'foo': 'bar',
                  telemetry_utils.LABEL_TFX_EXECUTOR: 'third_party_executor'
              }, **orig_labels))
    with telemetry_utils.scoped_labels({
        telemetry_utils.LABEL_TFX_EXECUTOR:
            'tfx.components.example_gen.import_example_gen.executor.Executor'
    }):
      self.assertDictEqual(
          telemetry_utils.make_labels_dict(),
          dict(
              {
                  'foo': 'bar',
                  telemetry_utils.LABEL_TFX_EXECUTOR:
                      # Label is normalized.
                      'tfx-components-example_gen-import_example_gen-executor-executor'
              }, **orig_labels))
    with telemetry_utils.scoped_labels({
        telemetry_utils.LABEL_TFX_EXECUTOR:
            'tfx.extensions.google_cloud_big_query.example_gen.executor.Executor'
    }):
      self.assertDictEqual(
          telemetry_utils.make_labels_dict(),
          dict(
              {
                  'foo': 'bar',
                  telemetry_utils.LABEL_TFX_EXECUTOR:
                      # Label is normalized.
                      'tfx-extensions-google_cloud_big_query-example_gen-executor-exec'
              }, **orig_labels))
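# The label normalization asserted above is not defined in this excerpt. Below
# is a minimal, hypothetical sketch of such a normalizer, assuming it follows
# GCP label-value constraints (lowercase; only letters, digits, '_' and '-';
# at most 63 characters). The helper name `_normalize_label_value` is an
# assumption, not the actual telemetry_utils implementation, and the separate
# rule that collapses non-TFX class paths to 'third_party_executor' is not
# sketched here.
import re


def _normalize_label_value(value: str, max_len: int = 63) -> str:
  """Lowercases, replaces disallowed characters with '-', and truncates."""
  value = value.lower()
  value = re.sub(r'[^a-z0-9_-]', '-', value)
  return value[:max_len]


# Example: the module path in the last assertion above normalizes and
# truncates to a 63-character label value.
assert _normalize_label_value(
    'tfx.extensions.google_cloud_big_query.example_gen.executor.Executor'
) == 'tfx-extensions-google_cloud_big_query-example_gen-executor-exec'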
def create_training_job(self, input_dict: Dict[str, List[types.Artifact]],
                        output_dict: Dict[str, List[types.Artifact]],
                        exec_properties: Dict[str, Any],
                        executor_class_path: str, job_args: Dict[str, Any],
                        job_id: Optional[str]) -> Dict[str, Any]:
  """Get training args for runner._launch_aip_training.

  The training args contain the inputs/outputs/exec_properties to the
  tfx.scripts.run_executor module.

  Args:
    input_dict: Passthrough input dict for tfx.components.Trainer.executor.
    output_dict: Passthrough input dict for tfx.components.Trainer.executor.
    exec_properties: Passthrough input dict for
      tfx.components.Trainer.executor.
    executor_class_path: class path for TFX core default trainer.
    job_args: Training input argument for AI Platform training job.
      'pythonModule', 'pythonVersion' and 'runtimeVersion' will be inferred.
      For the full set of parameters, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput
    job_id: Job ID for AI Platform Training job. If not supplied, a
      system-determined unique ID is given. Refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#resource-job

  Returns:
    A dict containing the training arguments.
  """
  training_inputs = job_args.copy()

  container_command = self.generate_container_command(input_dict, output_dict,
                                                      exec_properties,
                                                      executor_class_path)

  if not training_inputs.get('masterConfig'):
    training_inputs['masterConfig'] = {
        'imageUri': _TFX_IMAGE,
    }

  # Always use our own entrypoint instead of relying on container default.
  if 'containerCommand' in training_inputs['masterConfig']:
    logging.warn('Overriding custom value of containerCommand')
  training_inputs['masterConfig']['containerCommand'] = container_command

  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()

  # 'tfx_YYYYmmddHHMMSS' is the default job ID if not explicitly specified.
  job_id = job_id or 'tfx_{}'.format(
      datetime.datetime.now().strftime('%Y%m%d%H%M%S'))

  caip_job = {
      'job_id': job_id,
      'training_input': training_inputs,
      'labels': job_labels
  }
  return caip_job
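# A hedged usage sketch for the method above: `job_args` mirrors the Cloud AI
# Platform TrainingInput schema, and the returned dict is what
# runner._launch_aip_training consumes. The project, region, scale tier, and
# image below are illustrative assumptions, not defaults of this module.
example_caip_job_args = {
    'project': 'my-gcp-project',
    'region': 'us-central1',
    'scaleTier': 'BASIC_GPU',
    'masterConfig': {
        'imageUri': 'gcr.io/my-gcp-project/my-training-image:latest',
    },
}
# caip_job = self.create_training_job(
#     input_dict, output_dict, exec_properties,
#     executor_class_path='tfx.components.trainer.executor.GenericExecutor',
#     job_args=example_caip_job_args,
#     job_id=None)
# -> {'job_id': 'tfx_YYYYmmddHHMMSS', 'training_input': {...}, 'labels': {...}}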
def ReadFromBigQuery(pipeline: beam.Pipeline,
                     query: str) -> beam.pvalue.PCollection:
  """Read data from BigQuery.

  Args:
    pipeline: Beam pipeline.
    query: A BigQuery sql string.

  Returns:
    PCollection of dict.
  """
  return (pipeline
          | 'ReadFromBigQuery' >> bigquery.ReadFromBigQuery(
              query=query,
              use_standard_sql=True,
              bigquery_job_labels=telemetry_utils.make_labels_dict()))
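# A minimal usage sketch of ReadFromBigQuery above; the query and output path
# are illustrative assumptions. The labels from
# telemetry_utils.make_labels_dict() are attached to the BigQuery job that
# Beam issues on the pipeline's behalf.
def run_example_pipeline():
  with beam.Pipeline() as p:
    rows = ReadFromBigQuery(
        p, 'SELECT name, value FROM `my-project.my_dataset.my_table`')
    _ = (rows
         | 'Format' >> beam.Map(lambda row: '{name},{value}'.format(**row))
         | 'Write' >> beam.io.WriteToText('/tmp/bq_rows.txt'))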
def testDoBlessedOnRegionalEndpoint_Vertex(self, mock_runner):
  endpoint_uri = 'projects/project_id/locations/us-west1/endpoints/12345'
  mock_runner.deploy_model_for_aip_prediction.return_value = endpoint_uri
  self._exec_properties_vertex = {
      'custom_config': {
          constants.SERVING_ARGS_KEY: {
              'model_name': 'model_name',
              'project_id': 'project_id'
          },
          constants.VERTEX_CONTAINER_IMAGE_URI_KEY:
              self._container_image_uri_vertex,
          constants.ENABLE_VERTEX_KEY: True,
          constants.VERTEX_REGION_KEY: 'us-west1',
      },
  }
  self._model_blessing.uri = os.path.join(self._source_data_dir,
                                          'model_validator/blessed')
  self._model_blessing.set_int_custom_property('blessed', 1)
  self._executor.Do(self._input_dict, self._output_dict,
                    self._serialize_custom_config_under_test_vertex())
  executor_class_path = '%s.%s' % (self._executor.__class__.__module__,
                                   self._executor.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  mock_runner.deploy_model_for_aip_prediction.assert_called_once_with(
      serving_path=self._model_push.uri,
      model_version_name=mock.ANY,
      ai_platform_serving_args=mock.ANY,
      labels=job_labels,
      serving_container_image_uri=self._container_image_uri_vertex,
      endpoint_region='us-west1',
      enable_vertex=True,
  )
  self.assertPushed()
  self.assertEqual(
      self._model_push.get_string_custom_property('pushed_destination'),
      endpoint_uri)
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]):
  """Overrides the tfx_pusher_executor.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - model_export: exported model from trainer.
      - model_blessing: model blessing path from evaluator.
    output_dict: Output dict from key to a list of artifacts, including:
      - model_push: A list of 'ModelPushPath' artifact of size one. It will
        include the model in this push execution if the model was pushed.
    exec_properties: Mostly a passthrough input dict for
      tfx.components.Pusher.executor. custom_config.bigquery_serving_args is
      consumed by this class, including:
      - bq_dataset_id: ID of the dataset you're creating or replacing.
      - model_name: name of the model you're creating or replacing.
      - project_id: GCP project where the model will be stored. It is also
        the project where the query is executed unless a compute_project_id
        is provided.
      - compute_project_id: GCP project where the query is executed. If not
        provided, the query is executed in project_id.
      For the full set of parameters supported by BigQuery ML, refer to
      https://cloud.google.com/bigquery-ml/

  Returns:
    None

  Raises:
    ValueError:
      If bigquery_serving_args is not in exec_properties.custom_config.
      If pipeline_root is not 'gs://...'.
    RuntimeError: If the BigQuery job failed.

  Example usage:
    from tfx.extensions.google_cloud_big_query.pusher import executor

    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        custom_executor_spec=executor_spec.ExecutorClassSpec(executor.Executor),
        custom_config={
            'bigquery_serving_args': {
                'model_name': 'your_model_name',
                'project_id': 'your_gcp_storage_project',
                'bq_dataset_id': 'your_dataset_id',
                'compute_project_id': 'your_gcp_compute_project',
            },
        },
    )
  """
  self._log_startup(input_dict, output_dict, exec_properties)
  model_push = artifact_utils.get_single_instance(
      output_dict[standard_component_specs.PUSHED_MODEL_KEY])
  if not self.CheckBlessing(input_dict):
    self._MarkNotPushed(model_push)
    return

  custom_config = json_utils.loads(
      exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
  if custom_config is not None and not isinstance(custom_config, Dict):
    raise ValueError('custom_config in execution properties needs to be a '
                     'dict.')

  bigquery_serving_args = custom_config.get(SERVING_ARGS_KEY)
  # If the configuration is missing, error out.
  if bigquery_serving_args is None:
    raise ValueError('Big Query ML configuration was not provided')

  bq_model_uri = '.'.join([
      bigquery_serving_args[_PROJECT_ID_KEY],
      bigquery_serving_args[_BQ_DATASET_ID_KEY],
      bigquery_serving_args[_MODEL_NAME_KEY],
  ])

  # Deploy the model.
  io_utils.copy_dir(src=self.GetModelPath(input_dict), dst=model_push.uri)
  model_path = model_push.uri
  if not model_path.startswith(_GCS_PREFIX):
    raise ValueError('pipeline_root must be gs:// for BigQuery ML Pusher.')

  logging.info('Deploying the model to BigQuery ML for serving: %s from %s',
               bigquery_serving_args, model_path)

  query = _BQML_CREATE_OR_REPLACE_MODEL_QUERY_TEMPLATE.format(
      model_uri=bq_model_uri, model_path=model_path)

  # TODO(zhitaoli): Refactor the executor_class_path creation into a common
  # utility function.
  executor_class_path = '%s.%s' % (self.__class__.__module__,
                                   self.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    default_query_job_config = bigquery.job.QueryJobConfig(
        labels=telemetry_utils.make_labels_dict())
  # TODO(b/181368842) Add integration test for BQML Pusher + Managed Pipeline.
  project_id = (
      bigquery_serving_args.get(_COMPUTE_PROJECT_ID_KEY) or
      bigquery_serving_args[_PROJECT_ID_KEY])
  client = bigquery.Client(
      default_query_job_config=default_query_job_config, project=project_id)

  try:
    query_job = client.query(query)
    query_job.result()  # Waits for the query to finish.
  except Exception as e:
    raise RuntimeError('BigQuery ML Push failed: {}'.format(e)) from e

  logging.info('Successfully deployed model %s serving from %s', bq_model_uri,
               model_path)

  # Set the push destination to the BigQuery model URI.
  self._MarkPushed(model_push, pushed_destination=bq_model_uri)
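# The query template referenced above is not included in this excerpt. A
# plausible form, assuming BigQuery ML's TensorFlow model import syntax
# (CREATE OR REPLACE MODEL ... OPTIONS(model_type='tensorflow',
# model_path=...)), is sketched below; BigQuery ML expects a wildcard over the
# SavedModel files under the GCS directory. Treat this as a sketch, not the
# module's actual constant.
_BQML_CREATE_OR_REPLACE_MODEL_QUERY_TEMPLATE = """
CREATE OR REPLACE MODEL `{model_uri}`
OPTIONS (model_type='tensorflow',
         model_path='{model_path}/*')
"""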
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
  """Runs batch inference on a given model with given input examples.

  This function creates a new model (if necessary) and a new model version
  before inference, and cleans up resources after inference. It provides
  re-executability: it cleans up (only) the model resources that were created
  during the process, even if the inference job failed.

  Args:
    input_dict: Input dict from input key to a list of Artifacts.
      - examples: examples for inference.
      - model: exported model.
      - model_blessing: model blessing result.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: bulk inference results.
    exec_properties: A dict of execution properties.
      - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
      - custom_config: custom_config.ai_platform_serving_args needs to contain
        the serving job parameters sent to Google Cloud AI Platform. For the
        full set of parameters, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  if output_dict.get('inference_result'):
    inference_result = artifact_utils.get_single_instance(
        output_dict['inference_result'])
  else:
    inference_result = None
  if output_dict.get('output_examples'):
    output_examples = artifact_utils.get_single_instance(
        output_dict['output_examples'])
  else:
    output_examples = None

  if 'examples' not in input_dict:
    raise ValueError('`examples` is missing in input dict.')
  if 'model' not in input_dict:
    raise ValueError('Input models are not valid, model '
                     'needs to be specified.')
  if 'model_blessing' in input_dict:
    model_blessing = artifact_utils.get_single_instance(
        input_dict['model_blessing'])
    if not model_utils.is_model_blessed(model_blessing):
      logging.info('Model on %s was not blessed', model_blessing.uri)
      return
  else:
    logging.info('Model blessing is not provided, exported model will be '
                 'used.')
  if _CUSTOM_CONFIG_KEY not in exec_properties:
    raise ValueError('Input exec properties are not valid, {} '
                     'needs to be specified.'.format(_CUSTOM_CONFIG_KEY))

  custom_config = json_utils.loads(
      exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
  if custom_config is not None and not isinstance(custom_config, Dict):
    raise ValueError('custom_config in execution properties needs to be a '
                     'dict.')
  ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
  if not ai_platform_serving_args:
    raise ValueError(
        '`ai_platform_serving_args` is missing in `custom_config`')
  service_name, api_version = runner.get_service_name_and_api_version(
      ai_platform_serving_args)
  executor_class_path = '%s.%s' % (self.__class__.__module__,
                                   self.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  model = artifact_utils.get_single_instance(input_dict['model'])
  model_path = path_utils.serving_model_path(
      model.uri, path_utils.is_old_model_artifact(model))
  logging.info('Use exported model from %s.', model_path)
  # Use the model artifact uri to generate a model version, to guarantee the
  # 1:1 mapping from model version to model.
  model_version = 'version_' + hashlib.sha256(model.uri.encode()).hexdigest()
  inference_spec = self._get_inference_spec(model_path, model_version,
                                            ai_platform_serving_args)
  data_spec = bulk_inferrer_pb2.DataSpec()
  proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
  output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
  if exec_properties.get('output_example_spec'):
    proto_utils.json_to_proto(exec_properties['output_example_spec'],
                              output_example_spec)

  endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
  if endpoint and 'regions' in ai_platform_serving_args:
    raise ValueError('`endpoint` and `ai_platform_serving_args.regions` '
                     'cannot be set simultaneously.')

  api = discovery.build(
      service_name,
      api_version,
      requestBuilder=telemetry_utils.TFXHttpRequest,
      client_options=client_options.ClientOptions(api_endpoint=endpoint),
  )
  new_model_endpoint_created = False
  try:
    new_model_endpoint_created = (
        runner.create_model_for_aip_prediction_if_not_exist(
            job_labels, ai_platform_serving_args, api))
    runner.deploy_model_for_aip_prediction(
        serving_path=model_path,
        model_version_name=model_version,
        ai_platform_serving_args=ai_platform_serving_args,
        api=api,
        labels=job_labels,
        skip_model_endpoint_creation=True,
        set_default=False,
    )
    self._run_model_inference(data_spec, output_example_spec,
                              input_dict['examples'], output_examples,
                              inference_result, inference_spec)
  except Exception as e:
    logging.error('Error in executing CloudAIBulkInferrerComponent: %s',
                  str(e))
    raise
  finally:
    # Guarantee newly created resources are cleaned up even if the inference
    # job failed: clean up the newly deployed model.
    runner.delete_model_from_aip_if_exists(
        model_version_name=model_version,
        ai_platform_serving_args=ai_platform_serving_args,
        api=api,
        delete_model_endpoint=new_model_endpoint_created)
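# _get_inference_spec is called above but not shown in this excerpt. A sketch
# consistent with the test assertion earlier in this section (an
# InferenceSpecType wrapping an AIPlatformPredictionModelSpec with
# use_serialization_config enabled) could look like the following; the exact
# signature, use of model_path, and defaults are assumptions, and
# model_spec_pb2 is assumed to come from tfx_bsl.
def _get_inference_spec(
    self, model_path: str, model_version: str,
    ai_platform_serving_args: Dict[str, Any]
) -> model_spec_pb2.InferenceSpecType:
  """Builds an inference spec pointing at the deployed AI Platform model."""
  del model_path  # Not needed once the model is deployed to AI Platform.
  spec = model_spec_pb2.AIPlatformPredictionModelSpec(
      project_id=ai_platform_serving_args['project_id'],
      model_name=ai_platform_serving_args['model_name'],
      version_name=model_version)
  spec.use_serialization_config = True
  inference_spec = model_spec_pb2.InferenceSpecType()
  inference_spec.ai_platform_prediction_model_spec.CopyFrom(spec)
  return inference_spec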
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]):
  """Overrides the tfx_pusher_executor.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - model_export: exported model from trainer.
      - model_blessing: model blessing path from evaluator.
    output_dict: Output dict from key to a list of artifacts, including:
      - model_push: A list of 'ModelPushPath' artifact of size one. It will
        include the model in this push execution if the model was pushed.
    exec_properties: Mostly a passthrough input dict for
      tfx.components.Pusher.executor. The following keys in `custom_config`
      are consumed by this class:
      - ai_platform_serving_args: For the full set of parameters supported by
        - Google Cloud AI Platform, refer to
          https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions#Version
        - Google Cloud Vertex AI, refer to
          https://googleapis.dev/python/aiplatform/latest/aiplatform.html?highlight=deploy#google.cloud.aiplatform.Model.deploy
      - endpoint: Optional endpoint override.
        - For Google Cloud AI Platform, this should be in the format of
          `https://[region]-ml.googleapis.com`. Defaults to the global
          endpoint if not set. Using a regional endpoint is recommended by
          Cloud AI Platform. When set, the 'regions' key in
          ai_platform_serving_args cannot be set. For more details, please see
          https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints#using_regional_endpoints
        - For Google Cloud Vertex AI, this should just be the `region`
          (e.g. 'us-central1'). For available regions, please see
          https://cloud.google.com/vertex-ai/docs/general/locations

  Raises:
    ValueError:
      If ai_platform_serving_args is not in exec_properties.custom_config.
      If the serving model path does not start with gs://.
      If 'endpoint' and 'regions' are set simultaneously.
    RuntimeError: If the Google Cloud AI Platform training job failed.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  custom_config = json_utils.loads(
      exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
  if custom_config is not None and not isinstance(custom_config, Dict):
    raise ValueError('custom_config in execution properties needs to be a '
                     'dict.')
  ai_platform_serving_args = custom_config.get(constants.SERVING_ARGS_KEY)
  if not ai_platform_serving_args:
    raise ValueError(
        '\'ai_platform_serving_args\' is missing in \'custom_config\'')
  model_push = artifact_utils.get_single_instance(
      output_dict[standard_component_specs.PUSHED_MODEL_KEY])
  if not self.CheckBlessing(input_dict):
    self._MarkNotPushed(model_push)
    return

  # Deploy the model.
  io_utils.copy_dir(src=self.GetModelPath(input_dict), dst=model_push.uri)
  model_path = model_push.uri

  executor_class_path = '%s.%s' % (self.__class__.__module__,
                                   self.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()

  enable_vertex = custom_config.get(constants.ENABLE_VERTEX_KEY)
  if enable_vertex:
    if custom_config.get(constants.ENDPOINT_ARGS_KEY):
      deprecation_utils.warn_deprecated(
          '\'endpoint\' is deprecated. Please use '
          '\'ai_platform_vertex_region\' instead.')
    if 'regions' in ai_platform_serving_args:
      deprecation_utils.warn_deprecated(
          '\'ai_platform_serving_args.regions\' is deprecated. Please use '
          '\'ai_platform_vertex_region\' instead.')
    endpoint_region = custom_config.get(constants.VERTEX_REGION_KEY)
    # TODO(jjong): Introduce Versioning.
    # Note that we're adding a "v" prefix as Cloud AI Prediction only allows
    # version names that start with a letter and contain letters, digits, and
    # underscores only.
    model_name = 'v{}'.format(int(time.time()))
    container_image_uri = custom_config.get(
        constants.VERTEX_CONTAINER_IMAGE_URI_KEY)

    pushed_model_path = runner.deploy_model_for_aip_prediction(
        serving_container_image_uri=container_image_uri,
        model_version_name=model_name,
        ai_platform_serving_args=ai_platform_serving_args,
        endpoint_region=endpoint_region,
        labels=job_labels,
        serving_path=model_path,
        enable_vertex=True,
    )

    self._MarkPushed(model_push, pushed_destination=pushed_model_path)

  else:
    endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
    if endpoint and 'regions' in ai_platform_serving_args:
      raise ValueError(
          '\'endpoint\' and \'ai_platform_serving_args.regions\' cannot be '
          'set simultaneously.')
    # TODO(jjong): Introduce Versioning.
    # Note that we're adding a "v" prefix as Cloud AI Prediction only allows
    # version names that start with a letter and contain letters, digits, and
    # underscores only.
    model_version = 'v{}'.format(int(time.time()))
    endpoint = endpoint or runner.DEFAULT_ENDPOINT
    service_name, api_version = runner.get_service_name_and_api_version(
        ai_platform_serving_args)
    api = discovery.build(
        service_name,
        api_version,
        requestBuilder=telemetry_utils.TFXHttpRequest,
        client_options=client_options.ClientOptions(api_endpoint=endpoint),
    )
    pushed_model_version_path = runner.deploy_model_for_aip_prediction(
        serving_path=model_path,
        model_version_name=model_version,
        ai_platform_serving_args=ai_platform_serving_args,
        api=api,
        labels=job_labels,
    )
    self._MarkPushed(
        model_push,
        pushed_destination=pushed_model_version_path,
        pushed_version=model_version)
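# A hedged configuration sketch for the Vertex AI branch above, mirroring the
# Example usage shown for the BigQuery ML pusher earlier in this section and
# the exec_properties used in the regional-endpoint test. The serving image
# and region values are illustrative assumptions.
# pusher = Pusher(
#     model=trainer.outputs['model'],
#     model_blessing=evaluator.outputs['blessing'],
#     custom_config={
#         constants.ENABLE_VERTEX_KEY: True,
#         constants.VERTEX_REGION_KEY: 'us-central1',
#         constants.VERTEX_CONTAINER_IMAGE_URI_KEY:
#             'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest',
#         constants.SERVING_ARGS_KEY: {
#             'model_name': 'my_model',
#             'project_id': 'my-gcp-project',
#         },
#     })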
def create_training_job(self, input_dict: Dict[str, List[types.Artifact]],
                        output_dict: Dict[str, List[types.Artifact]],
                        exec_properties: Dict[str, Any],
                        executor_class_path: str, job_args: Dict[str, Any],
                        job_id: Optional[str]) -> Dict[str, Any]:
  """Get training args for runner._launch_aip_training.

  The training args contain the inputs/outputs/exec_properties to the
  tfx.scripts.run_executor module.

  Args:
    input_dict: Passthrough input dict for tfx.components.Trainer.executor.
    output_dict: Passthrough input dict for tfx.components.Trainer.executor.
    exec_properties: Passthrough input dict for
      tfx.components.Trainer.executor.
    executor_class_path: class path for TFX core default trainer.
    job_args: CustomJob for Vertex AI custom training. See
      https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob
      for the detailed schema.
      [Deprecated]: job_args also supports specifying only the CustomJobSpec
      instead of the full CustomJob; however, this usage is deprecated.
    job_id: Display name for AI Platform (Unified) custom training job. If not
      supplied, a system-determined unique ID is given. Refer to
      https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob

  Returns:
    A dict containing the Vertex AI CustomJob.
  """
  job_args = job_args.copy()

  if job_args.get('job_spec'):
    custom_job_spec = job_args['job_spec']
  else:
    logging.warn(
        'job_spec key was not found. Parsing job_args as CustomJobSpec '
        'instead.')
    custom_job_spec = job_args

  container_command = self.generate_container_command(input_dict, output_dict,
                                                      exec_properties,
                                                      executor_class_path)

  if not custom_job_spec.get('worker_pool_specs'):
    custom_job_spec['worker_pool_specs'] = [{}]

  for worker_pool_spec in custom_job_spec['worker_pool_specs']:
    if not worker_pool_spec.get('container_spec'):
      worker_pool_spec['container_spec'] = {
          'image_uri': _TFX_IMAGE,
      }

    # Always use our own entrypoint instead of relying on container default.
    if 'command' in worker_pool_spec['container_spec']:
      logging.warn('Overriding custom value of container_spec.command')
    worker_pool_spec['container_spec']['command'] = container_command

  # 'tfx_YYYYmmddHHMMSS_xxxxxxxx' is the default job display name if not
  # explicitly specified.
  job_id = job_args.get('display_name', job_id)
  job_id = job_id or 'tfx_{}_{}'.format(
      datetime.datetime.now().strftime('%Y%m%d%H%M%S'),
      '%08x' % random.getrandbits(32))

  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.make_labels_dict()
  job_labels.update(job_args.get('labels', {}))

  encryption_spec = job_args.get('encryption_spec', {})

  custom_job = {
      'display_name': job_id,
      'job_spec': custom_job_spec,
      'labels': job_labels,
      'encryption_spec': encryption_spec,
  }
  return custom_job
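# A hedged sketch of `job_args` for the Vertex AI variant above, following the
# CustomJob schema linked in the docstring. The display name, machine type,
# replica count, image, and labels are illustrative assumptions.
example_vertex_job_args = {
    'display_name': 'my-tfx-training-job',
    'job_spec': {
        'worker_pool_specs': [{
            'machine_spec': {
                'machine_type': 'n1-standard-8',
            },
            'replica_count': 1,
            'container_spec': {
                'image_uri': 'gcr.io/my-gcp-project/my-tfx-image:latest',
            },
        }],
    },
    'labels': {
        'team': 'ml-platform',
    },
}
# create_training_job would overwrite container_spec.command with the
# generated container command, merge telemetry labels with the user-provided
# ones, and fall back to a 'tfx_YYYYmmddHHMMSS_xxxxxxxx' display name when
# none is provided.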