Esempio n. 1
0
 def testImageVersion(self):
     """Checks the image tag returned for several TFX version strings."""
     # Each pair is (version string passed in, image tag expected back).
     expectations = (
         ('0.25.0', '0.25.0'),
         ('0.25.0-rc1', '0.25.0rc1'),
         ('0.25.0.dev20201101', '0.25.0.dev20201101'),
         ('0.26.0.dev', 'latest'),
     )
     for tfx_version, image_tag in expectations:
         actual_tag = version_utils.get_image_version(tfx_version)
         self.assertEqual(actual_tag, image_tag)
Esempio n. 2
0
    def testStartAIPTraining(self, mock_discovery):
        """Checks the job request that `runner.start_aip_training` submits.

        Args:
            mock_discovery: Mock whose `build()` is set up to return the fake
                API client held by this test case.
        """
        mock_discovery.build.return_value = self._mock_api_client
        self._setUpTrainingMocks()

        class_path = 'foo.bar.class'

        runner.start_aip_training(self._inputs, self._outputs,
                                  self._serialize_custom_config_under_test(),
                                  class_path, self._training_inputs, None)

        self._mock_create.assert_called_with(body=mock.ANY,
                                             parent='projects/{}'.format(
                                                 self._project_id))
        (_, kwargs) = self._mock_create.call_args
        body = kwargs['body']

        default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(
            version_utils.get_image_version())
        expected_training_input = {
            'masterConfig': {
                'imageUri':
                default_image,
                'containerCommand':
                runner._CONTAINER_COMMAND + [
                    '--executor_class_path', class_path, '--inputs', '{}',
                    '--outputs', '{}', '--exec-properties',
                    '{"custom_config": '
                    '"{\\"ai_platform_training_args\\": {\\"project\\": \\"12345\\"'
                    '}}"}'
                ],
            },
        }
        # `assertDictContainsSubset` is deprecated since Python 3.2 and was
        # removed from `unittest.TestCase` in Python 3.12, so every expected
        # entry is checked explicitly instead.
        training_input = body['trainingInput']
        for key, expected_value in expected_training_input.items():
            self.assertIn(key, training_input)
            self.assertEqual(expected_value, training_input[key])
        self.assertStartsWith(body['jobId'], 'tfx_')
        self._mock_get.execute.assert_called_with()
Esempio n. 3
0
    def testStartAIPTraining_uCAIP(self, mock_gapic):
        """Checks the custom job that `runner.start_aip_training` submits.

        This variant passes `True` and a region as the two trailing
        positional arguments of `start_aip_training`.

        Args:
            mock_gapic: Mock whose `JobServiceClient` is set up to return the
                fake API client held by this test case.
        """
        mock_gapic.JobServiceClient.return_value = self._mock_api_client
        self._setUpUcaipTrainingMocks()

        class_path = 'foo.bar.class'
        region = 'us-central1'

        runner.start_aip_training(self._inputs, self._outputs,
                                  self._serialize_custom_config_under_test(),
                                  class_path, self._training_inputs, None,
                                  True, region)

        self._mock_create.assert_called_with(
            parent='projects/{}/locations/{}'.format(self._project_id, region),
            custom_job=mock.ANY)
        (_, kwargs) = self._mock_create.call_args
        body = kwargs['custom_job']

        default_image = 'gcr.io/tfx-oss-public/tfx:{}'.format(
            version_utils.get_image_version())
        expected_job_spec = {
            'worker_pool_specs': [
                {
                    'container_spec': {
                        'image_uri':
                        default_image,
                        'command':
                        runner._CONTAINER_COMMAND + [
                            '--executor_class_path', class_path,
                            '--inputs', '{}', '--outputs', '{}',
                            '--exec-properties', '{"custom_config": '
                            '"{\\"ai_platform_training_args\\": '
                            '{\\"project\\": \\"12345\\"'
                            '}}"}'
                        ],
                    },
                },
            ],
        }
        # `assertDictContainsSubset` is deprecated since Python 3.2 and was
        # removed from `unittest.TestCase` in Python 3.12, so every expected
        # entry is checked explicitly instead.
        job_spec = body['job_spec']
        for key, expected_value in expected_job_spec.items():
            self.assertIn(key, job_spec)
            self.assertEqual(expected_value, job_spec[key])
        self.assertStartsWith(body['display_name'], 'tfx_')
        self._mock_get.assert_called_with(name='ucaip_job_study_id')
Esempio n. 4
0
from tfx.utils import telemetry_utils
from tfx.utils import version_utils

from google.protobuf import json_format
from tensorflow.python.util import deprecation  # pylint: disable=g-direct-tensorflow-import

# Default container command for KubeflowV2DagRunner: a `python -m` invocation
# of the kubeflow_v2_run_executor module.
_KUBEFLOW_TFX_CMD = (
    'python',
    '-m',
    'tfx.orchestration.kubeflow.v2.container.kubeflow_v2_run_executor',
)

# Schema version currently used for the API proto.
_SCHEMA_VERSION = 'v2alpha1'

# Default TFX container image for KubeflowV2DagRunner; the tag is whatever
# version_utils.get_image_version() returns when this module is imported.
_KUBEFLOW_TFX_IMAGE = (
    f'gcr.io/tfx-oss-public/tfx:{version_utils.get_image_version()}')


def _get_current_time():
    """Returns the current time as reported by `datetime.datetime.now()`.

    No timezone argument is passed, so the result is a naive datetime.
    """
    current_time = datetime.datetime.now()
    return current_time


class KubeflowV2DagRunnerConfig(pipeline_config.PipelineConfig):
    """Runtime configuration specific to execution on Kubeflow pipelines."""
    def __init__(self,
                 project_id: Text,
                 display_name: Optional[Text] = None,
                 default_image: Optional[Text] = None,
                 default_commands: Optional[List[Text]] = None,
                 **kwargs):