def testImageVersion(self):
  """Checks the expected mapping from package versions to image tags.

  Release versions are expected to pass through unchanged, a `-rc1`
  pre-release suffix to be written as `rc1`, a dated dev build to keep its
  full version string, and an undated `.dev` build to map to `latest`.
  """
  expectations = (
      ('0.25.0', '0.25.0'),
      ('0.25.0-rc1', '0.25.0rc1'),
      ('0.25.0.dev20201101', '0.25.0.dev20201101'),
      ('0.26.0.dev', 'latest'),
  )
  # Checked in order; like the sequential asserts, the first mismatch fails.
  for package_version, expected_tag in expectations:
    actual_tag = version_utils.get_image_version(package_version)
    self.assertEqual(actual_tag, expected_tag)
def testStartAIPTraining(self, mock_discovery):
  """Verifies start_aip_training submits a CAIP training job.

  The job must be created under `projects/<project_id>`, and its master
  config must use the default TFX image for the current version and run
  the executor container command with the serialized inputs, outputs and
  exec properties. The job id must carry the `tfx_` prefix.
  """
  mock_discovery.build.return_value = self._mock_api_client
  self._setUpTrainingMocks()
  class_path = 'foo.bar.class'

  runner.start_aip_training(self._inputs, self._outputs,
                            self._serialize_custom_config_under_test(),
                            class_path, self._training_inputs, None)

  self._mock_create.assert_called_with(
      body=mock.ANY, parent='projects/{}'.format(self._project_id))
  (_, kwargs) = self._mock_create.call_args
  body = kwargs['body']

  default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(
      version_utils.get_image_version())
  expected_master_config = {
      'imageUri': default_image,
      'containerCommand': runner._CONTAINER_COMMAND + [
          '--executor_class_path', class_path,
          '--inputs', '{}',
          '--outputs', '{}',
          '--exec-properties',
          # Adjacent literals are intentionally concatenated into one
          # JSON-encoded exec-properties argument.
          '{"custom_config": '
          '"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'
          '}}"}'
      ],
  }
  # assertDictContainsSubset was deprecated in Python 3.2 and removed in
  # 3.12. Check the single expected key explicitly; this passes and fails
  # in exactly the same cases.
  training_input = body['trainingInput']
  self.assertIn('masterConfig', training_input)
  self.assertEqual(expected_master_config, training_input['masterConfig'])

  self.assertStartsWith(body['jobId'], 'tfx_')
  self._mock_get.execute.assert_called_with()
def testStartAIPTraining_uCAIP(self, mock_gapic):
  """Verifies start_aip_training submits a uCAIP custom job.

  The job must be created through the mocked JobServiceClient under
  `projects/<project_id>/locations/<region>`. Its worker pool specs must
  consist of one container spec using the default TFX image and the
  executor container command. The display name must carry the `tfx_`
  prefix.
  """
  mock_gapic.JobServiceClient.return_value = self._mock_api_client
  self._setUpUcaipTrainingMocks()
  class_path = 'foo.bar.class'
  region = 'us-central1'

  # NOTE(review): the positional `True` presumably enables the uCAIP code
  # path -- confirm against runner.start_aip_training's signature.
  runner.start_aip_training(self._inputs, self._outputs,
                            self._serialize_custom_config_under_test(),
                            class_path, self._training_inputs, None, True,
                            region)

  self._mock_create.assert_called_with(
      parent='projects/{}/locations/{}'.format(self._project_id, region),
      custom_job=mock.ANY)
  (_, kwargs) = self._mock_create.call_args
  body = kwargs['custom_job']

  default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(
      version_utils.get_image_version())
  expected_worker_pool_specs = [
      {
          'container_spec': {
              'image_uri': default_image,
              'command': runner._CONTAINER_COMMAND + [
                  '--executor_class_path', class_path,
                  '--inputs', '{}',
                  '--outputs', '{}',
                  '--exec-properties',
                  # Adjacent literals are intentionally concatenated into
                  # one JSON-encoded exec-properties argument.
                  '{"custom_config": '
                  '"{\\"ai_platform_training_args\\": '
                  '{\\"project\\": \\"12345\\"'
                  '}}"}'
              ],
          },
      },
  ]
  # assertDictContainsSubset was deprecated in Python 3.2 and removed in
  # 3.12. Check the single expected key explicitly; this passes and fails
  # in exactly the same cases.
  job_spec = body['job_spec']
  self.assertIn('worker_pool_specs', job_spec)
  self.assertEqual(expected_worker_pool_specs, job_spec['worker_pool_specs'])

  self.assertStartsWith(body['display_name'], 'tfx_')
  self._mock_get.assert_called_with(name='ucaip_job_study_id')
from tfx.utils import telemetry_utils from tfx.utils import version_utils from google.protobuf import json_format from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import _KUBEFLOW_TFX_CMD = ( 'python', '-m', 'tfx.orchestration.kubeflow.v2.container.kubeflow_v2_run_executor') # Current schema version for the API proto. _SCHEMA_VERSION = 'v2alpha1' # Default TFX container image/commands to use in KubeflowV2DagRunner. _KUBEFLOW_TFX_IMAGE = 'gcr.io/tfx-oss-public/tfx:{}'.format( version_utils.get_image_version()) def _get_current_time(): """Gets the current timestamp.""" return datetime.datetime.now() class KubeflowV2DagRunnerConfig(pipeline_config.PipelineConfig): """Runtime configuration specific to execution on Kubeflow pipelines.""" def __init__(self, project_id: Text, display_name: Optional[Text] = None, default_image: Optional[Text] = None, default_commands: Optional[List[Text]] = None, **kwargs):