def run(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput:
  """Stamps a fixed span and a computed version onto the output example.

  Looks up this component's previous output artifacts in MLMD, derives the
  next version number for the (hard-coded) span, and returns a DriverOutput
  carrying the stamped artifact plus matching exec properties.
  """
  # A production driver would normally derive the span from the date; a
  # constant is faked here instead.
  span = 2

  with self._mlmd_connection as m:
    prior_outputs = inputs_utils.resolve_input_artifacts(m, self._self_output)

  # Version is 0 for a brand-new span; otherwise it is one greater than the
  # largest version previously recorded for this span.
  version = 0
  if prior_outputs:
    versions_for_span = [
        art.get_int_custom_property('version')
        for art in prior_outputs['examples']
        if art.get_int_custom_property('span') == span
    ]
    version = max(versions_for_span, default=-1) + 1

  # Stamp span/version onto a copy of the declared output artifact.
  stamped_example = copy.deepcopy(
      output_dict['output_examples'][0].mlmd_artifact)
  stamped_example.custom_properties['span'].int_value = span
  stamped_example.custom_properties['version'].int_value = version

  result = driver_output_pb2.DriverOutput()
  result.output_artifacts['output_examples'].artifacts.append(stamped_example)
  result.exec_properties['span'].int_value = span
  result.exec_properties['version'].int_value = version
  return result
def run(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput:
  """Resolves exec properties and applies them to the output example artifact.

  Returns a DriverOutput containing every non-None resolved exec property and
  a copy of the output Examples artifact updated from those properties.
  """
  driver_output = driver_output_pb2.DriverOutput()

  # Neither PipelineInfo nor ComponentInfo is actually consumed; empty
  # placeholders are built solely so the legacy resolve_exec_properties
  # signature can be satisfied.
  placeholder_pipeline = data_types.PipelineInfo('', '')
  placeholder_component = data_types.ComponentInfo('', '',
                                                   placeholder_pipeline)
  resolved_props = self.resolve_exec_properties(exec_properties,
                                                placeholder_pipeline,
                                                placeholder_component)

  # Copy every non-None resolved property into the driver output.
  for name, value in resolved_props.items():
    if value is None:
      continue
    common_utils.set_metadata_value(driver_output.exec_properties[name],
                                    value)

  # Update a copy of the declared output Examples artifact and attach it.
  stamped_example = copy.deepcopy(
      output_dict[utils.EXAMPLES_KEY][0].mlmd_artifact)
  _update_output_artifact(resolved_props, stamped_example)
  driver_output.output_artifacts[utils.EXAMPLES_KEY].artifacts.append(
      stamped_example)
  return driver_output
def run(
    self, execution_info: portable_data_types.ExecutionInfo
) -> driver_output_pb2.DriverOutput:
  """Resolves exec properties and applies them to the output example artifact.

  Portable variant: inputs arrive bundled in an ExecutionInfo. Returns a
  DriverOutput containing every non-None resolved exec property and a copy of
  the output Examples artifact updated from those properties.
  """
  driver_output = driver_output_pb2.DriverOutput()

  # Neither PipelineInfo nor ComponentInfo is actually consumed; empty
  # placeholders are built solely so the legacy resolve_exec_properties
  # signature can be satisfied.
  placeholder_pipeline = data_types.PipelineInfo('', '')
  placeholder_component = data_types.ComponentInfo('', '',
                                                   placeholder_pipeline)
  resolved_props = self.resolve_exec_properties(
      execution_info.exec_properties, placeholder_pipeline,
      placeholder_component)

  # Copy every non-None resolved property into the driver output.
  for name, value in resolved_props.items():
    if value is None:
      continue
    data_types_utils.set_metadata_value(driver_output.exec_properties[name],
                                        value)

  # Update a copy of the declared output Examples artifact and attach it.
  examples_key = standard_component_specs.EXAMPLES_KEY
  stamped_example = copy.deepcopy(
      execution_info.output_dict[examples_key][0].mlmd_artifact)
  update_output_artifact(resolved_props, stamped_example)
  driver_output.output_artifacts[examples_key].artifacts.append(
      stamped_example)
  return driver_output
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for tfx.orchestration.portable.python_driver_operator.""" from typing import Any, Dict, List, Text import tensorflow as tf from tfx import types from tfx.orchestration.portable import base_driver from tfx.orchestration.portable import python_driver_operator from tfx.proto.orchestration import driver_output_pb2 from tfx.proto.orchestration import executable_spec_pb2 _DEFAULT_DRIVER_OUTPUT = driver_output_pb2.DriverOutput() class _FakeNoopDriver(base_driver.BaseDriver): def run(self, input_dict: Dict[Text, List[types.Artifact]], output_dict: Dict[Text, List[types.Artifact]], exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput: return _DEFAULT_DRIVER_OUTPUT class PythonDriverOperatorTest(tf.test.TestCase): def succeed(self): custom_driver_spec = (executable_spec_pb2.PythonClassExecutableSpec()) custom_driver_spec.class_path = 'tfx.orchestration.portable.python_driver_operator._FakeNoopDriver'
def testOverrideRegisterExecution(self):
  """Checks that the KFP pod name is injected into registered exec properties.

  Runs container_entrypoint.main against a canned IR file with every real
  driver / executor / MLMD operation mocked out, then verifies that
  register_execution received the pod name from the environment under
  _KFP_POD_NAME_PROPERTY_KEY.
  """
  # Mock all real operations of driver / executor / MLMD accesses.
  mock_targets = (  # (cls, method, return_value)
      (beam_executor_operator.BeamExecutorOperator, '__init__', None),
      (beam_executor_operator.BeamExecutorOperator, 'run_executor',
       execution_result_pb2.ExecutorOutput()),
      (python_driver_operator.PythonDriverOperator, '__init__', None),
      (python_driver_operator.PythonDriverOperator, 'run_driver',
       driver_output_pb2.DriverOutput()),
      (metadata.Metadata, '__init__', None),
      (metadata.Metadata, '__exit__', None),
      (launcher.Launcher, '_publish_successful_execution', None),
      (launcher.Launcher, '_clean_up_stateless_execution_info', None),
      (launcher.Launcher, '_clean_up_stateful_execution_info', None),
      (outputs_utils, 'OutputsResolver', mock.MagicMock()),
      (execution_lib, 'get_executions_associated_with_all_contexts', []),
      (container_entrypoint, '_dump_ui_metadata', None),
  )
  for cls, method, return_value in mock_targets:
    self.enter_context(
        mock.patch.object(cls, method, autospec=True,
                          return_value=return_value))

  # Metadata.__enter__ is patched separately so the mocked store can serve a
  # canned Execution when the launcher looks one up by id.
  mock_mlmd = self.enter_context(
      mock.patch.object(metadata.Metadata, '__enter__',
                        autospec=True)).return_value
  mock_mlmd.store.return_value.get_executions_by_id.return_value = [
      metadata_store_pb2.Execution()
  ]

  # The pod-name env var is the value the assertion at the end checks for.
  self._set_required_env_vars({
      'WORKFLOW_ID': 'workflow-id-42',
      'METADATA_GRPC_SERVICE_HOST': 'metadata-grpc',
      'METADATA_GRPC_SERVICE_PORT': '8080',
      container_entrypoint._KFP_POD_NAME_ENV_KEY: 'test_pod_name'
  })

  mock_register_execution = self.enter_context(
      mock.patch.object(execution_publish_utils, 'register_execution',
                        autospec=True))

  # Canned pipeline IR checked in next to this test.
  test_ir_file = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), 'testdata',
      'two_step_pipeline_post_dehydrate_ir.json')
  test_ir = io_utils.read_string_file(test_ir_file)

  argv = [
      '--pipeline_root',
      'dummy',
      '--kubeflow_metadata_config',
      json_format.MessageToJson(
          kubeflow_dag_runner.get_default_kubeflow_metadata_config()),
      '--tfx_ir',
      test_ir,
      '--node_id',
      'BigQueryExampleGen',
      '--runtime_parameter',
      'pipeline-run-id=STRING:my-run-id',
  ]
  container_entrypoint.main(argv)

  # register_execution must be called exactly once, with the pod name from
  # the environment surfaced as an exec property.
  mock_register_execution.assert_called_once()
  kwargs = mock_register_execution.call_args[1]
  self.assertEqual(
      kwargs['exec_properties']
      [container_entrypoint._KFP_POD_NAME_PROPERTY_KEY], 'test_pod_name')