Ejemplo n.º 1
0
    def run(self, input_dict: Dict[Text, List[types.Artifact]],
            output_dict: Dict[Text, List[types.Artifact]],
            exec_properties: Dict[Text,
                                  Any]) -> driver_output_pb2.DriverOutput:
        """Emits a copy of the output Examples stamped with a span and version.

        The span is a constant fake value. The version is one past the largest
        `version` custom property among previously resolved `examples`
        artifacts whose `span` matches, or 0 when there is no such artifact.

        Args:
          input_dict: Unused.
          output_dict: Must hold at least one artifact under
            'output_examples'; the first one's MLMD artifact is deep-copied.
          exec_properties: Unused.

        Returns:
          A DriverOutput with the stamped artifact under 'output_examples' and
          matching integer 'span' / 'version' exec properties.
        """
        # Constant fake span; in production this is usually derived from a
        # date.
        fake_span = 2
        with self._mlmd_connection as mlmd_handle:
            prior = inputs_utils.resolve_input_artifacts(
                mlmd_handle, self._self_output)
            # Nothing resolved -> version 0; otherwise one past the highest
            # version already recorded for this span (or 0 if none match).
            next_version = 0
            if prior:
                versions_for_span = [
                    artifact.get_int_custom_property('version')
                    for artifact in prior['examples']
                    if artifact.get_int_custom_property('span') == fake_span
                ]
                next_version = max(versions_for_span, default=-1) + 1

        template = output_dict['output_examples'][0].mlmd_artifact
        stamped = copy.deepcopy(template)
        stamps = (('span', fake_span), ('version', next_version))
        # Stamp the copy before appending: appending a message to a repeated
        # proto field stores a copy, so later edits would not be reflected.
        for key, value in stamps:
            stamped.custom_properties[key].int_value = value

        result = driver_output_pb2.DriverOutput()
        result.output_artifacts['output_examples'].artifacts.append(stamped)
        for key, value in stamps:
            result.exec_properties[key].int_value = value
        return result
Ejemplo n.º 2
0
Archivo: driver.py Proyecto: lre/tfx
    def run(self, input_dict: Dict[Text, List[types.Artifact]],
            output_dict: Dict[Text, List[types.Artifact]],
            exec_properties: Dict[Text,
                                  Any]) -> driver_output_pb2.DriverOutput:
        """Resolves exec properties and emits a copy of the output Examples.

        Args:
          input_dict: Unused.
          output_dict: Must hold at least one artifact under
            `utils.EXAMPLES_KEY`; the first one's MLMD artifact is deep-copied
            and passed to `_update_output_artifact` with the resolved
            properties.
          exec_properties: Raw exec properties, run through
            `self.resolve_exec_properties` before use.

        Returns:
          A DriverOutput holding every resolved exec property whose value is
          not None, plus the copied Examples artifact.
        """
        # resolve_exec_properties still follows the old API that takes a
        # PipelineInfo and a ComponentInfo. Neither is actually used, so empty
        # placeholders are passed only to satisfy that signature.
        placeholder_pipeline = data_types.PipelineInfo('', '')
        placeholder_component = data_types.ComponentInfo('', '',
                                                         placeholder_pipeline)
        resolved = self.resolve_exec_properties(exec_properties,
                                                placeholder_pipeline,
                                                placeholder_component)

        driver_output = driver_output_pb2.DriverOutput()
        for name, value in resolved.items():
            # None-valued properties are left out of the proto entirely.
            if value is None:
                continue
            common_utils.set_metadata_value(driver_output.exec_properties[name],
                                            value)

        source_artifact = output_dict[utils.EXAMPLES_KEY][0]
        examples = copy.deepcopy(source_artifact.mlmd_artifact)
        _update_output_artifact(resolved, examples)
        driver_output.output_artifacts[utils.EXAMPLES_KEY].artifacts.append(
            examples)
        return driver_output
Ejemplo n.º 3
0
    def run(
        self, execution_info: portable_data_types.ExecutionInfo
    ) -> driver_output_pb2.DriverOutput:
        """Resolves exec properties and emits a copy of the output Examples.

        Args:
          execution_info: Supplies the raw exec properties and the output
            dict. The output dict must hold at least one artifact under
            `standard_component_specs.EXAMPLES_KEY`; the first one's MLMD
            artifact is deep-copied and passed to `update_output_artifact`
            together with the resolved properties.

        Returns:
          A DriverOutput holding every resolved exec property whose value is
          not None, plus the copied Examples artifact.
        """
        # The old resolve_exec_properties API still requires a PipelineInfo
        # and a ComponentInfo; these empty placeholders exist only to satisfy
        # it and are not otherwise used.
        placeholder_pipeline = data_types.PipelineInfo('', '')
        placeholder_component = data_types.ComponentInfo(
            '', '', placeholder_pipeline)
        resolved = self.resolve_exec_properties(execution_info.exec_properties,
                                                placeholder_pipeline,
                                                placeholder_component)

        driver_output = driver_output_pb2.DriverOutput()
        # Only properties with a concrete value are copied into the proto.
        present = {k: v for k, v in resolved.items() if v is not None}
        for name, value in present.items():
            data_types_utils.set_metadata_value(
                driver_output.exec_properties[name], value)

        examples_key = standard_component_specs.EXAMPLES_KEY
        source_artifact = execution_info.output_dict[examples_key][0]
        examples = copy.deepcopy(source_artifact.mlmd_artifact)
        update_output_artifact(resolved, examples)
        driver_output.output_artifacts[examples_key].artifacts.append(examples)
        return driver_output
Ejemplo n.º 4
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.orchestration.portable.python_driver_operator."""

from typing import Any, Dict, List, Text

import tensorflow as tf
from tfx import types
from tfx.orchestration.portable import base_driver
from tfx.orchestration.portable import python_driver_operator
from tfx.proto.orchestration import driver_output_pb2
from tfx.proto.orchestration import executable_spec_pb2

# Empty DriverOutput that _FakeNoopDriver.run returns as-is (the same object on
# every call).
_DEFAULT_DRIVER_OUTPUT = driver_output_pb2.DriverOutput()


class _FakeNoopDriver(base_driver.BaseDriver):
  """Driver stub that ignores its arguments and returns a fixed output."""

  def run(self, input_dict: Dict[Text, List[types.Artifact]],
          output_dict: Dict[Text, List[types.Artifact]],
          exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput:
    """Returns the module-level `_DEFAULT_DRIVER_OUTPUT` object itself."""
    del input_dict, output_dict, exec_properties  # Unused.
    return _DEFAULT_DRIVER_OUTPUT


class PythonDriverOperatorTest(tf.test.TestCase):
  """Test case for the python_driver_operator module."""

  # NOTE(review): this method name does not start with `test`, so the
  # unittest-based loader behind tf.test.TestCase will not collect it. Rename
  # (e.g. `testSucceed`) if it is meant to run.
  def succeed(self):
    # Executable spec that identifies a Python driver class by import path.
    custom_driver_spec = (executable_spec_pb2.PythonClassExecutableSpec())
    # NOTE(review): _FakeNoopDriver is defined in this test module, but this
    # path points into python_driver_operator; confirm it actually resolves.
    custom_driver_spec.class_path = 'tfx.orchestration.portable.python_driver_operator._FakeNoopDriver'
Ejemplo n.º 5
0
    def testOverrideRegisterExecution(self):
        """Checks the KFP pod name reaches register_execution's exec props.

        The driver/executor operators, MLMD access and the launcher's publish
        and clean-up steps are patched out. container_entrypoint.main is then
        run on a test IR for the BigQueryExampleGen node, and the test asserts
        that execution_publish_utils.register_execution was called once with
        the pod name (taken from the environment) stored under
        container_entrypoint._KFP_POD_NAME_PROPERTY_KEY.
        """
        # Mock all real operations of driver / executor / MLMD accesses.
        mock_targets = (  # (cls, method, return_value)
            (beam_executor_operator.BeamExecutorOperator, '__init__', None),
            (beam_executor_operator.BeamExecutorOperator, 'run_executor',
             execution_result_pb2.ExecutorOutput()),
            (python_driver_operator.PythonDriverOperator, '__init__', None),
            (python_driver_operator.PythonDriverOperator, 'run_driver',
             driver_output_pb2.DriverOutput()),
            (metadata.Metadata, '__init__', None),
            (metadata.Metadata, '__exit__', None),
            (launcher.Launcher, '_publish_successful_execution', None),
            (launcher.Launcher, '_clean_up_stateless_execution_info', None),
            (launcher.Launcher, '_clean_up_stateful_execution_info', None),
            (outputs_utils, 'OutputsResolver', mock.MagicMock()),
            (execution_lib, 'get_executions_associated_with_all_contexts', []),
            (container_entrypoint, '_dump_ui_metadata', None),
        )
        # autospec=True makes each patch check calls against the real
        # signature; enter_context undoes every patch when the test finishes.
        for cls, method, return_value in mock_targets:
            self.enter_context(
                mock.patch.object(cls,
                                  method,
                                  autospec=True,
                                  return_value=return_value))

        # __enter__ is patched on its own so the object bound by `with` can be
        # configured: its store reports one empty Execution for id lookups.
        # NOTE(review): this configures the result of calling `store()`, not
        # the `store` attribute itself — confirm that matches how the code
        # under test reaches get_executions_by_id.
        mock_mlmd = self.enter_context(
            mock.patch.object(metadata.Metadata, '__enter__',
                              autospec=True)).return_value
        mock_mlmd.store.return_value.get_executions_by_id.return_value = [
            metadata_store_pb2.Execution()
        ]

        # Environment read by container_entrypoint, including the pod name
        # that the assertion below expects to be forwarded.
        self._set_required_env_vars({
            'WORKFLOW_ID':
            'workflow-id-42',
            'METADATA_GRPC_SERVICE_HOST':
            'metadata-grpc',
            'METADATA_GRPC_SERVICE_PORT':
            '8080',
            container_entrypoint._KFP_POD_NAME_ENV_KEY:
            'test_pod_name'
        })

        # Patched separately from mock_targets so its call arguments can be
        # inspected after main() returns.
        mock_register_execution = self.enter_context(
            mock.patch.object(execution_publish_utils,
                              'register_execution',
                              autospec=True))

        # Serialized pipeline IR fixture loaded from the testdata directory
        # next to this file.
        test_ir_file = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'testdata',
            'two_step_pipeline_post_dehydrate_ir.json')
        test_ir = io_utils.read_string_file(test_ir_file)

        # Command line for running the BigQueryExampleGen node of that IR.
        argv = [
            '--pipeline_root',
            'dummy',
            '--kubeflow_metadata_config',
            json_format.MessageToJson(
                kubeflow_dag_runner.get_default_kubeflow_metadata_config()),
            '--tfx_ir',
            test_ir,
            '--node_id',
            'BigQueryExampleGen',
            '--runtime_parameter',
            'pipeline-run-id=STRING:my-run-id',
        ]
        container_entrypoint.main(argv)

        # The pod name from the environment must appear in the exec
        # properties passed (as a keyword argument) to register_execution.
        mock_register_execution.assert_called_once()
        kwargs = mock_register_execution.call_args[1]
        self.assertEqual(
            kwargs['exec_properties']
            [container_entrypoint._KFP_POD_NAME_PROPERTY_KEY], 'test_pod_name')