def _compile_node_outputs(self, tfx_node: base_node.BaseNode,
                          node_pb: pipeline_pb2.PipelineNode):
  """Compiles the outputs of a node/component.

  For every entry of `tfx_node.outputs`, copies the artifact type into the
  matching output spec of `node_pb` and stores the channel's additional
  properties and additional custom properties on that spec.

  Args:
    tfx_node: The node whose `outputs` are compiled.
    node_pb: The PipelineNode proto that is populated in place.

  Raises:
    ValueError: If `data_types_utils.set_metadata_value` rejects a property
      value. The original error is attached as the cause.
  """

  def _set_property(value_field, property_name, property_value):
    # Shared by both property kinds so that the error wording stays identical.
    try:
      data_types_utils.set_metadata_value(value_field, property_value)
    except ValueError as e:
      # Chain from the caught instance (not the ValueError class) so the
      # underlying failure reason is preserved in the traceback.
      raise ValueError(
          "Component {} got unsupported parameter {} with type {}.".format(
              tfx_node.id, property_name, type(property_value))) from e

  for key, value in tfx_node.outputs.items():
    output_spec = node_pb.outputs.outputs[key]
    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    output_spec.artifact_spec.type.CopyFrom(artifact_type)

    # Attach additional properties for artifacts produced by importer nodes.
    for property_name, property_value in value.additional_properties.items():
      _check_property_value_type(property_name, property_value, artifact_type)
      value_field = output_spec.artifact_spec.additional_properties[
          property_name].field_value
      _set_property(value_field, property_name, property_value)

    # Custom properties are not validated against the artifact type.
    for property_name, property_value in (
        value.additional_custom_properties.items()):
      value_field = output_spec.artifact_spec.additional_custom_properties[
          property_name].field_value
      _set_property(value_field, property_name, property_value)
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[Text, types.Property]] = None,
) -> metadata_store_pb2.Execution:
  """Builds an execution proto from the given type, state and properties.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_type: A metadata_pb2.ExecutionType message describing the type
      of the execution.
    state: The state of the execution.
    exec_properties: Execution properties that need to be attached. `None` is
      treated like an empty mapping.

  Returns:
    A metadata_store_pb2.Execution message.
  """
  execution = metadata_store_pb2.Execution()
  execution.last_known_state = state
  registered_type = common_utils.register_type_if_not_exist(
      metadata_handler, execution_type)
  execution.type_id = registered_type.id

  # A property goes to `properties` only when the execution type declares its
  # key with the same value type; everything else becomes a custom property.
  for name, prop in (exec_properties or {}).items():
    declared_type = execution_type.properties.get(name)
    if declared_type == data_types_utils.get_metadata_value_type(prop):
      destination = execution.properties
    else:
      destination = execution.custom_properties
    data_types_utils.set_metadata_value(destination[name], prop)

  logging.debug('Prepared EXECUTION:\n %s', execution)
  return execution
def substitute_runtime_parameter(
    msg: message.Message,
    parameter_bindings: Mapping[str, types.Property]) -> None:
  """Replaces runtime parameter placeholders in `msg` with bound values.

  The message is modified in place. Anything that is not a proto message is
  ignored.

  Args:
    msg: The original message to change. Only messages defined under
      pipeline_pb2 will be supported. Other types will result in no-op.
    parameter_bindings: A dict of parameter keys to parameter values that will
      be used to substitute the runtime parameter placeholder.

  Returns:
    None
  """
  if not isinstance(msg, message.Message):
    return

  # A pipeline_pb2.Value is a leaf: resolve its placeholder (if it holds one)
  # and stop.
  if isinstance(msg, pipeline_pb2.Value):
    value = cast(pipeline_pb2.Value, msg)
    which = value.WhichOneof('value')
    if which == 'runtime_parameter':
      real_value = _get_runtime_parameter_value(value.runtime_parameter,
                                                parameter_bindings)
    elif which == 'structural_runtime_parameter':
      real_value = _get_structural_runtime_parameter_value(
          value.structural_runtime_parameter, parameter_bindings)
    else:
      return
    # Without a resolved value the placeholder is left untouched.
    if real_value is None:
      return
    value.Clear()
    data_types_utils.set_metadata_value(
        metadata_value=value.field_value, value=real_value)
    return

  # Any other message: walk its populated fields and recurse into messages.
  for field, sub_message in msg.ListFields():
    if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
      # Scalars cannot contain placeholders.
      continue
    entry_type = field.message_type
    if entry_type.has_options and entry_type.GetOptions().map_entry:
      # Map field: visit every value.
      for key in sub_message:
        substitute_runtime_parameter(sub_message[key], parameter_bindings)
    elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      # Repeated field: visit every element.
      for element in sub_message:
        substitute_runtime_parameter(element, parameter_bindings)
    else:
      # Singular sub-message.
      substitute_runtime_parameter(sub_message, parameter_bindings)
def test_pipeline_failure_strategies(self, fail_fast):
  """Tests pipeline failure strategies."""
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  for expected_node in (self._stats_gen, self._schema_gen):
    self._run_next(False, expect_nodes=[expected_node], fail_fast=fail_fast)

  # Both example-validator and transform are ready to execute.
  ready_tasks = self._generate(False, True, fail_fast=fail_fast)
  [example_validator_task, transform_task] = ready_tasks
  self.assertEqual(self._example_validator.node_info.id,
                   example_validator_task.node_uid.node_id)
  self.assertEqual(self._transform.node_info.id,
                   transform_task.node_uid.node_id)

  # Simulate Transform success.
  self._finish_node_execution(False, transform_task)

  # But fail example-validator by marking its execution FAILED in MLMD.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, example_validator_task.execution_id) as ev_exec:
    ev_exec.last_known_state = metadata_store_pb2.Execution.FAILED
    error_property = ev_exec.custom_properties[
        constants.EXECUTION_ERROR_MSG_KEY]
    data_types_utils.set_metadata_value(error_property,
                                        'example-validator error')

  if not fail_fast:
    # Trainer and downstream nodes can execute as transform has finished.
    # example-validator failure does not impact them as it is not upstream.
    # Pipeline run will still fail but when no more progress can be made.
    for expected_node in (self._trainer, self._chore_a, self._chore_b):
      self._run_next(False, expect_nodes=[expected_node], fail_fast=fail_fast)

  # With fail_fast the pipeline run fails immediately because
  # example-validator failed; otherwise it fails once nothing else can run.
  [finalize_task] = self._generate(False, True, fail_fast=fail_fast)
  self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
  self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
def _update_execution_state_in_mlmd(
    mlmd_handle: metadata.Metadata,
    execution: metadata_store_pb2.Execution,
    new_state: metadata_store_pb2.Execution.State,
    error_msg: str) -> None:
  """Stores a copy of `execution` with a new state in MLMD.

  The passed-in `execution` is not modified; a deep copy is updated and put.

  Args:
    mlmd_handle: Handle whose `store` receives the updated execution.
    execution: The execution to copy.
    new_state: Value assigned to `last_known_state` of the copy.
    error_msg: If non-empty, recorded under
      `constants.EXECUTION_ERROR_MSG_KEY` in the copy's custom properties.
  """
  execution_copy = copy.deepcopy(execution)
  execution_copy.last_known_state = new_state
  if error_msg:
    error_property = execution_copy.custom_properties[
        constants.EXECUTION_ERROR_MSG_KEY]
    data_types_utils.set_metadata_value(error_property, error_msg)
  mlmd_handle.store.put_executions([execution_copy])
def _update_execution_state_in_mlmd(
    mlmd_handle: metadata.Metadata,
    execution_id: int,
    new_state: metadata_store_pb2.Execution.State,
    error_msg: str) -> None:
  """Updates state (and optional error message) of an execution by id.

  The change is made inside `mlmd_state.mlmd_execution_atomic_op`.

  Args:
    mlmd_handle: Handle passed to the atomic op.
    execution_id: Id of the execution to update.
    new_state: Value assigned to `last_known_state`.
    error_msg: If non-empty, recorded under
      `constants.EXECUTION_ERROR_MSG_KEY` in the custom properties.
  """
  atomic_op = mlmd_state.mlmd_execution_atomic_op(mlmd_handle, execution_id)
  with atomic_op as target_execution:
    target_execution.last_known_state = new_state
    if error_msg:
      error_property = target_execution.custom_properties[
          constants.EXECUTION_ERROR_MSG_KEY]
      data_types_utils.set_metadata_value(error_property, error_msg)
def testSetMetadataValueWithTfxValue(self):
  """A pipeline Value holding a field_value is copied into the MLMD value."""
  source_value = pipeline_pb2.Value()
  text_format.Parse(
      """
      field_value {
        int_value: 1
      }""", source_value)

  target_value = metadata_store_pb2.Value()
  data_types_utils.set_metadata_value(
      metadata_value=target_value, value=source_value)

  self.assertProtoEquals('int_value: 1', target_value)
def testSetMetadataValueWithTfxValueFailed(self):
  """A pipeline Value holding a runtime parameter is rejected."""
  source_value = pipeline_pb2.Value()
  text_format.Parse(
      """
      runtime_parameter {
        name: 'rp'
      }""", source_value)

  target_value = metadata_store_pb2.Value()
  expected_error = 'Expecting field_value but got'
  with self.assertRaisesRegex(ValueError, expected_error):
    data_types_utils.set_metadata_value(
        metadata_value=target_value, value=source_value)
def initiate_update(
    self,
    updated_pipeline: pipeline_pb2.Pipeline,
    update_options: pipeline_pb2.UpdateOptions,
) -> None:
  """Initiates pipeline update process.

  Validates the updated pipeline against the current one and, if it is
  acceptable, records the encoded pipeline IR and update options in the
  execution's custom properties.

  Args:
    updated_pipeline: IR of the pipeline to switch to.
    update_options: Options accompanying the update.

  Raises:
    status_lib.StatusNotOkError: With code INVALID_ARGUMENT if the execution
      mode differs, if the run id of a sync pipeline differs, or if the node
      structure differs from the current pipeline.
  """
  self._check_context()

  def _invalid_argument(text: str) -> status_lib.StatusNotOkError:
    return status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT, message=text)

  if self.pipeline.execution_mode != updated_pipeline.execution_mode:
    raise _invalid_argument(
        'Updating execution_mode of an active pipeline is not '
        'supported')

  if self.pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    run_id_value = updated_pipeline.runtime_spec.pipeline_run_id.field_value
    updated_pipeline_run_id = run_id_value.string_value
    if self.pipeline_run_id != updated_pipeline_run_id:
      raise _invalid_argument(
          f'For sync pipeline, pipeline_run_id should match; found '
          f'mismatch: {self.pipeline_run_id} (existing) vs. '
          f'{updated_pipeline_run_id} (updated)')

  # TODO(b/194311197): We require that structure of the updated pipeline
  # exactly matches the original. There is scope to relax this restriction.
  def _topology(
      pipeline: pipeline_pb2.Pipeline
  ) -> List[Tuple[str, List[str], List[str]]]:
    # One (id, upstream ids, downstream ids) tuple per node, in node order.
    topology = []
    for node in get_all_pipeline_nodes(pipeline):
      topology.append((node.node_info.id, list(node.upstream_nodes),
                       list(node.downstream_nodes)))
    return topology

  if _topology(self.pipeline) != _topology(updated_pipeline):
    raise _invalid_argument(
        'Updated pipeline should have the same structure as the '
        'original.')

  data_types_utils.set_metadata_value(
      self._execution.custom_properties[_UPDATED_PIPELINE_IR],
      _base64_encode(updated_pipeline))
  data_types_utils.set_metadata_value(
      self._execution.custom_properties[_UPDATE_OPTIONS],
      _base64_encode(update_options))
def new(cls, mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call to
  be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with same UID already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  # Test the predicate itself; `any(e for e in ... if pred(e))` would instead
  # depend on the truthiness of the execution object.
  if any(execution_lib.is_execution_active(e) for e in executions):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  # The pipeline IR is stored as base64 text of the serialized proto.
  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      },
  )
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    # Sync pipelines additionally record their run id.
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def apply_pipeline_update(self) -> None:
  """Applies pipeline update that was previously initiated.

  Moves the pending pipeline IR from the custom properties into the
  `_PIPELINE_IR` property, removes the pending-update entries and replaces
  `self.pipeline` with the decoded IR.

  Raises:
    status_lib.StatusNotOkError: With code INVALID_ARGUMENT if no updated
      pipeline IR is stored.
  """
  self._check_context()

  pending_ir = _get_metadata_value(
      self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
  if not pending_ir:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message='No updated pipeline IR to apply')

  ir_property = self._execution.properties[_PIPELINE_IR]
  data_types_utils.set_metadata_value(ir_property, pending_ir)

  # The update is no longer pending once it has been promoted.
  for pending_key in (_UPDATED_PIPELINE_IR, _UPDATE_OPTIONS):
    del self._execution.custom_properties[pending_key]

  self.pipeline = _base64_decode_pipeline(pending_ir)
def test_restart_node_cancelled_due_to_stopping(self):
  """Tests that a node previously cancelled due to stopping can be restarted."""
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  first_round_tasks = self._generate_and_test(
      False,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      ignore_update_node_state_tasks=True)
  [stats_gen_task] = first_round_tasks
  node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                 self._stats_gen)
  self.assertEqual(node_uid, stats_gen_task.node_uid)

  # Simulate stopping the node while it is under execution, which leads to
  # the node execution being cancelled.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, stats_gen_task.execution_id) as stats_gen_exec:
    stats_gen_exec.last_known_state = metadata_store_pb2.Execution.CANCELED
    error_property = stats_gen_exec.custom_properties[
        constants.EXECUTION_ERROR_MSG_KEY]
    data_types_utils.set_metadata_value(error_property, 'manually stopped')

  # Change state of node to STARTING.
  with self._mlmd_connection as m:
    pipeline_state = test_utils.get_or_create_pipeline_state(
        m, self._pipeline)
    with pipeline_state, pipeline_state.node_state_update_context(
        node_uid) as node_state:
      node_state.update(pstate.NodeState.STARTING)

  # New execution should be created for any previously canceled node when the
  # node state is STARTING.
  second_round_tasks = self._generate_and_test(
      False,
      num_initial_executions=2,
      num_tasks_generated=2,
      num_new_executions=1,
      num_active_executions=1)
  [update_node_state_task, stats_gen_task] = second_round_tasks
  self.assertTrue(task_lib.is_update_node_state_task(update_node_state_task))
  self.assertEqual(node_uid, update_node_state_task.node_uid)
  self.assertEqual(pstate.NodeState.RUNNING, update_node_state_task.state)
  self.assertEqual(node_uid, stats_gen_task.node_uid)
def initiate_node_stop(self, node_uid: task_lib.NodeUid) -> None:
  """Updates pipeline state to signal that a node should be stopped.

  Args:
    node_uid: Uid of the node to stop.

  Raises:
    status_lib.StatusNotOkError: With code UNIMPLEMENTED if the pipeline is
      not in ASYNC execution mode, or with code INVALID_ARGUMENT if the node
      does not belong to this pipeline.
  """
  if self.pipeline.execution_mode != pipeline_pb2.Pipeline.ASYNC:
    # The message previously said "started", which was misleading in the
    # stop path.
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.UNIMPLEMENTED,
        message='Node can be stopped only for async pipelines.')
  if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=(f'Node given by uid {node_uid} does not belong to pipeline '
                 f'given by uid {self.pipeline_uid}'))
  # Record the stop request as an integer flag keyed by the node.
  data_types_utils.set_metadata_value(
      self.execution.custom_properties[_node_stop_initiated_property(
          node_uid)], 1)
  # Flag the state as modified.
  self._commit = True
def test_node_failed(self, use_task_queue):
  """Tests task generation when a node registers a failed execution."""
  otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1)

  def _ensure_node_services(unused_pipeline_state, node_id):
    # Only example-gen is expected to need node services here.
    self.assertEqual(self._example_gen.node_info.id, node_id)
    return service_jobs.ServiceStatus.SUCCESS

  self._mock_service_job_manager.ensure_node_services.side_effect = (
      _ensure_node_services)

  tasks, active_executions = self._generate_and_test(
      use_task_queue,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1)
  expected_node_uid = task_lib.NodeUid.from_pipeline_node(
      self._pipeline, self._stats_gen)
  self.assertEqual(expected_node_uid, tasks[0].node_uid)

  # Fail stats-gen execution.
  stats_gen_exec = active_executions[0]
  stats_gen_exec.last_known_state = metadata_store_pb2.Execution.FAILED
  error_property = stats_gen_exec.custom_properties[
      constants.EXECUTION_ERROR_MSG_KEY]
  data_types_utils.set_metadata_value(error_property, 'foobar error')
  with self._mlmd_connection as m:
    m.store.put_executions([stats_gen_exec])

  if use_task_queue:
    # Drain the task that was enqueued for stats-gen.
    task = self._task_queue.dequeue()
    self._task_queue.task_done(task)

  # Test generation of FinalizePipelineTask.
  tasks, _ = self._generate_and_test(
      True,
      num_initial_executions=2,
      num_tasks_generated=1,
      num_new_executions=0,
      num_active_executions=0)
  self.assertLen(tasks, 1)
  finalize_task = tasks[0]
  self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
  self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
  self.assertRegexMatch(finalize_task.status.message, ['foobar error'])
def test_node_failed(self, fail_fast):
  """Tests task generation when a node registers a failed execution."""
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  first_round_tasks = self._generate_and_test(
      False,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      ignore_update_node_state_tasks=True,
      fail_fast=fail_fast)
  [stats_gen_task] = first_round_tasks
  expected_node_uid = task_lib.NodeUid.from_pipeline_node(
      self._pipeline, self._stats_gen)
  self.assertEqual(expected_node_uid, stats_gen_task.node_uid)

  # Fail stats-gen execution.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, stats_gen_task.execution_id) as stats_gen_exec:
    stats_gen_exec.last_known_state = metadata_store_pb2.Execution.FAILED
    error_property = stats_gen_exec.custom_properties[
        constants.EXECUTION_ERROR_MSG_KEY]
    data_types_utils.set_metadata_value(error_property, 'foobar error')

  # Test generation of FinalizePipelineTask.
  second_round_tasks = self._generate_and_test(
      True,
      num_initial_executions=2,
      num_tasks_generated=2,
      num_new_executions=0,
      num_active_executions=0,
      fail_fast=fail_fast)
  [update_node_state_task, finalize_task] = second_round_tasks

  # The node-state update reports the failure of stats-gen ...
  self.assertTrue(task_lib.is_update_node_state_task(update_node_state_task))
  self.assertEqual('my_statistics_gen',
                   update_node_state_task.node_uid.node_id)
  self.assertEqual(pstate.NodeState.FAILED, update_node_state_task.state)
  self.assertRegexMatch(update_node_state_task.status.message,
                        ['foobar error'])

  # ... and the pipeline is finalized as ABORTED with the same error.
  self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
  self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
  self.assertRegexMatch(finalize_task.status.message, ['foobar error'])
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
  """Builds an execution proto from the given type, state and properties.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_type: A metadata_pb2.ExecutionType message describing the type
      of the execution.
    state: The state of the execution.
    exec_properties: Execution properties that need to be attached. `None` is
      treated like an empty mapping.

  Returns:
    A metadata_store_pb2.Execution message.
  """
  execution = metadata_store_pb2.Execution()
  execution.last_known_state = state
  registered_type = common_utils.register_type_if_not_exist(
      metadata_handler, execution_type)
  execution.type_id = registered_type.id

  for name, prop in (exec_properties or {}).items():
    parameter = data_types_utils.set_parameter_value(pipeline_pb2.Value(),
                                                     prop)

    if parameter.HasField('schema'):
      # Stores schema in custom_properties for non-primitive types to allow
      # parsing in later stages.
      schema_property = execution.custom_properties[get_schema_key(name)]
      data_types_utils.set_metadata_value(
          schema_property, proto_utils.proto_to_json(parameter.schema))

    # A property goes to `properties` only when the execution type declares
    # its key with the same value type; otherwise it is a custom property.
    declared_type = execution_type.properties.get(name)
    if declared_type == data_types_utils.get_metadata_value_type(prop):
      destination = execution.properties
    else:
      destination = execution.custom_properties
    destination[name].CopyFrom(parameter.field_value)

  logging.debug('Prepared EXECUTION:\n %s', execution)
  return execution
def run(
    self, execution_info: portable_data_types.ExecutionInfo
) -> driver_output_pb2.DriverOutput:
  """Resolves exec properties and prepares the examples output artifact.

  Args:
    execution_info: Supplies the incoming `exec_properties` and the
      `output_dict` holding the examples artifact.

  Returns:
    A DriverOutput carrying the resolved, non-None exec properties and a
    copy of the first examples output artifact after `update_output_artifact`
    has been applied to it.
  """
  # PipelineInfo and ComponentInfo are not actually used, two fake one are
  # created just to be compatible with the old API.
  fake_pipeline_info = data_types.PipelineInfo('', '')
  fake_component_info = data_types.ComponentInfo('', '', fake_pipeline_info)
  exec_properties = self.resolve_exec_properties(
      execution_info.exec_properties, fake_pipeline_info,
      fake_component_info)

  driver_output = driver_output_pb2.DriverOutput()

  # Populate exec_properties, skipping entries that resolved to None.
  for name, prop in exec_properties.items():
    if prop is None:
      continue
    data_types_utils.set_metadata_value(driver_output.exec_properties[name],
                                        prop)

  # Populate output_dict from a copy of the first examples artifact.
  examples_key = standard_component_specs.EXAMPLES_KEY
  examples_artifact = execution_info.output_dict[examples_key][0]
  output_example = copy.deepcopy(examples_artifact.mlmd_artifact)
  update_output_artifact(exec_properties, output_example)
  driver_output.output_artifacts[examples_key].artifacts.append(
      output_example)
  return driver_output
def _generate_context_proto(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
  """Builds a metadata_store_pb2.Context from a ContextSpec message.

  Args:
    metadata_handler: A handler to access MLMD store.
    context_spec: A pipeline_pb2.ContextSpec message that instructs registering
      of a context.

  Returns:
    A metadata_store_pb2.Context message.

  Raises:
    RuntimeError: When actual property type does not match provided metadata
      type schema.
  """
  context_type = common_utils.register_type_if_not_exist(
      metadata_handler, context_spec.type)
  context_name = data_types_utils.get_value(context_spec.name)
  assert isinstance(context_name, Text), 'context name should be string.'
  context = metadata_store_pb2.Context(
      type_id=context_type.id, name=context_name)

  for key, prop in context_spec.properties.items():
    if k_not_declared := key not in context_type.properties:
      # Keys unknown to the context type are stored as custom properties.
      data_types_utils.set_metadata_value(context.custom_properties[key],
                                          prop)
      continue

    # Declared keys must carry a value of the declared type.
    actual_property_type = data_types_utils.get_metadata_value_type(prop)
    if context_type.properties.get(key) != actual_property_type:
      raise RuntimeError(
          'Property type %s different from provided metadata type property type %s for key %s'
          % (actual_property_type, context_type.properties.get(key), key))
    data_types_utils.set_metadata_value(context.properties[key], prop)

  return context
def initiate_stop(self, status: status_lib.Status) -> None:
  """Updates pipeline state to signal stopping pipeline execution.

  Args:
    status: Its code is always recorded; its message only when non-empty.
  """
  self._check_context()

  # (custom property key, value) pairs that are always written.
  stop_markers = (
      (_STOP_INITIATED, 1),
      (_PIPELINE_STATUS_CODE, int(status.code)),
  )
  for property_key, property_value in stop_markers:
    data_types_utils.set_metadata_value(
        self._execution.custom_properties[property_key], property_value)

  if status.message:
    data_types_utils.set_metadata_value(
        self._execution.custom_properties[_PIPELINE_STATUS_MSG],
        status.message)
def _compile_node(
    self,
    tfx_node: base_node.BaseNode,
    compile_context: _CompilerContext,
    deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
    enable_cache: bool,
) -> pipeline_pb2.PipelineNode:
  """Compiles an individual TFX node into a PipelineNode proto.

  The node proto is filled in numbered steps: node info, contexts, inputs,
  outputs, parameters, executor/driver specs, upstream/downstream nodes,
  execution options and per-node platform config. Executor, driver and
  platform specs are written into `deployment_config` (mutated in place).

  Args:
    tfx_node: A TFX node.
    compile_context: Resources needed to compile the node.
    deployment_config: Intermediate deployment config to set. Will include
      related specs for executors, drivers and platform specific configs.
    enable_cache: whether cache is enabled

  Raises:
    TypeError: When supplied tfx_node has values of invalid type.
    ValueError: When an exec property value cannot be stored as a parameter
      field value.

  Returns:
    A PipelineNode proto that encodes information of the node.
  """
  node = pipeline_pb2.PipelineNode()

  # Step 1: Node info
  node.node_info.type.name = tfx_node.type
  node.node_info.id = tfx_node.id

  # Step 2: Node Context
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

  # Context for the current pipeline run. Its name is a runtime parameter,
  # so it is only resolved when the pipeline actually runs.
  if compile_context.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    compiler_utils.set_runtime_parameter_pb(
        pipeline_run_context_pb.name.runtime_parameter,
        constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = "{}.{}".format(
      compile_context.pipeline_info.pipeline_context_name, node.node_info.id)

  # Pre Step 3: Alter graph topology if needed.
  if compile_context.is_async_mode:
    tfx_node_inputs = self._compile_resolver_config(
        compile_context, tfx_node, node)
  else:
    tfx_node_inputs = tfx_node.inputs

  # Step 3: Node inputs
  for key, value in tfx_node_inputs.items():
    input_spec = node.inputs.inputs[key]
    channel = input_spec.channels.add()
    if value.producer_component_id:
      channel.producer_node_query.id = value.producer_component_id

      # Here we rely on pipeline.components to be topologically sorted.
      assert value.producer_component_id in compile_context.node_pbs, (
          "producer component should have already been compiled.")
      producer_pb = compile_context.node_pbs[
          value.producer_component_id]
      for producer_context in producer_pb.contexts.contexts:
        # NOTE(review): the run context above is created with the runtime
        # parameter name PIPELINE_RUN_ID_PARAMETER_NAME, but this check
        # compares against PIPELINE_RUN_CONTEXT_TYPE_NAME. Unless the two
        # constants are equal, the condition is always true — confirm intent.
        if (not compiler_utils.is_resolver(tfx_node) or
            producer_context.name.runtime_parameter.name !=
            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
          context_query = channel.context_queries.add()
          context_query.type.CopyFrom(producer_context.type)
          context_query.name.CopyFrom(producer_context.name)
    else:
      # Caveat: portable core requires every channel to have at least one
      # Context. But for cases like system nodes and producer-consumer
      # pipelines, a channel may not have contexts at all. In these cases,
      # we want to use the pipeline level context as the input channel
      # context.
      context_query = channel.context_queries.add()
      context_query.type.CopyFrom(pipeline_context_pb.type)
      context_query.name.CopyFrom(pipeline_context_pb.name)

    artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
    channel.artifact_query.type.CopyFrom(artifact_type)
    # The artifact query carries the type without its property schema.
    channel.artifact_query.type.ClearField("properties")

    if value.output_key:
      channel.output_key = value.output_key

    # TODO(b/158712886): Calculate min_count based on if inputs are optional.
    # min_count = 0 stands for optional input and 1 stands for required input.

  # Step 3.1: Special treatment for Resolver node.
  if compiler_utils.is_resolver(tfx_node):
    assert compile_context.is_sync_mode
    node.inputs.resolver_config.resolver_steps.extend(
        _convert_to_resolver_steps(tfx_node))

  # Step 4: Node outputs
  if isinstance(tfx_node, base_component.BaseComponent):
    for key, value in tfx_node.outputs.items():
      output_spec = node.outputs.outputs[key]
      artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
      output_spec.artifact_spec.type.CopyFrom(artifact_type)
      # Additional properties are checked against the artifact type before
      # being stored; additional custom properties are stored unchecked.
      for prop_key, prop_value in value.additional_properties.items():
        _check_property_value_type(prop_key, prop_value,
                                   output_spec.artifact_spec.type)
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_properties[prop_key]
            .field_value, prop_value)
      for prop_key, prop_value in value.additional_custom_properties.items():
        data_types_utils.set_metadata_value(
            output_spec.artifact_spec.additional_custom_properties[prop_key]
            .field_value, prop_value)

  # TODO(b/170694459): Refactor special nodes as plugins.
  # Step 4.1: Special treatment for Importer node
  if compiler_utils.is_importer(tfx_node):
    self._compile_importer_node_outputs(tfx_node, node)

  # Step 5: Node parameters
  if not compiler_utils.is_resolver(tfx_node):
    for key, value in tfx_node.exec_properties.items():
      # Unset exec properties produce no parameter entry.
      if value is None:
        continue
      # Ignore following two properties for a importer node, because they are
      # already attached to the artifacts produced by the importer node.
      if compiler_utils.is_importer(tfx_node) and (
          key == importer.PROPERTIES_KEY or
          key == importer.CUSTOM_PROPERTIES_KEY):
        continue
      parameter_value = node.parameters.parameters[key]

      # Order matters, because runtime parameter can be in serialized string.
      if isinstance(value, data_types.RuntimeParameter):
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, value.name, value.ptype,
            value.default)
      elif isinstance(value, str) and re.search(
          data_types.RUNTIME_PARAMETER_PATTERN, value):
        # NOTE(review): if `json` is the standard-library module, `loads`
        # returns a dict and the attribute access below (`.name`, `.ptype`,
        # `.default`) would raise AttributeError — confirm which `json` is
        # in scope.
        runtime_param = json.loads(value)
        compiler_utils.set_runtime_parameter_pb(
            parameter_value.runtime_parameter, runtime_param.name,
            runtime_param.ptype, runtime_param.default)
      else:
        try:
          data_types_utils.set_metadata_value(
              parameter_value.field_value, value)
        except ValueError:
          # Re-raise with the component id and parameter name for context.
          raise ValueError(
              "Component {} got unsupported parameter {} with type {}."
              .format(tfx_node.id, key, type(value)))

  # Step 6: Executor spec and optional driver spec for components
  if isinstance(tfx_node, base_component.BaseComponent):
    executor_spec = tfx_node.executor_spec.encode(
        component_spec=tfx_node.spec)
    deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

    # TODO(b/163433174): Remove specialized logic once generalization of
    # driver spec is done.
    # Only a driver class other than BaseDriver is recorded, by class path.
    if tfx_node.driver_class != base_driver.BaseDriver:
      driver_class_path = "{}.{}".format(
          tfx_node.driver_class.__module__, tfx_node.driver_class.__name__)
      driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
      driver_spec.class_path = driver_class_path
      deployment_config.custom_driver_specs[tfx_node.id].Pack(
          driver_spec)

  # Step 7: Upstream/Downstream nodes
  # Note: the order of tfx_node.upstream_nodes is inconsistent from
  # run to run. We sort them so that compiler generates consistent results.
  # For ASYNC mode upstream/downstream node information is not set as
  # compiled IR graph topology can be different from that on pipeline
  # authoring time; for example ResolverNode is removed.
  if compile_context.is_sync_mode:
    # The generator's loop variable `node` is local to the generator and does
    # not rebind the PipelineNode proto `node` being built.
    node.upstream_nodes.extend(
        sorted(node.id for node in tfx_node.upstream_nodes))
    node.downstream_nodes.extend(
        sorted(node.id for node in tfx_node.downstream_nodes))

  # Step 8: Node execution options
  node.execution_options.caching_options.enable_cache = enable_cache

  # Step 9: Per-node platform config
  if isinstance(tfx_node, base_component.BaseComponent):
    tfx_component = cast(base_component.BaseComponent, tfx_node)
    if tfx_component.platform_config:
      deployment_config.node_level_platform_configs[
          tfx_node.id].Pack(tfx_component.platform_config)

  return node
def testSetMetadataValueUnsupportedType(self):
  """A dict value is rejected with ValueError."""
  target_value = metadata_store_pb2.Value()
  unsupported_value = {'a': 1}
  with self.assertRaises(ValueError):
    data_types_utils.set_metadata_value(target_value, unsupported_value)
def testSetMetadataValueWithPrimitiveValue(self, value, expected_pb):
  """Setting `value` on an empty MLMD value yields `expected_pb`."""
  actual_pb = metadata_store_pb2.Value()
  data_types_utils.set_metadata_value(actual_pb, value)
  self.assertEqual(actual_pb, expected_pb)
def set_mlmd_value(
    self, value: metadata_store_pb2.Value) -> metadata_store_pb2.Value:
  """Stores the `json_utils.dumps` form of this object in `value`.

  Args:
    value: The MLMD value to populate in place.

  Returns:
    The same `value` object that was passed in.
  """
  serialized_self = json_utils.dumps(self)
  data_types_utils.set_metadata_value(value, serialized_self)
  return value
def new(
    cls,
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline,
    pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call to
  be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.
    pipeline_run_metadata: Pipeline run metadata. When given (and non-empty)
      it is stored JSON-serialized with the execution.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with same UID already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  # Test the predicate itself; `any(e for e in ... if pred(e))` would instead
  # depend on the truthiness of the execution object.
  if any(execution_lib.is_execution_active(e) for e in executions):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
  if pipeline_run_metadata:
    exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
        pipeline_run_metadata)

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties=exec_properties)
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    # Sync pipelines additionally record their run id.
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  # Set the node state to COMPLETE for any nodes that are marked to be
  # skipped in a partial pipeline run.
  node_states_dict = {}
  for node in get_all_pipeline_nodes(pipeline):
    if node.execution_options.HasField('skip'):
      logging.info('Node %s is skipped in this partial run.',
                   node.node_info.id)
      node_states_dict[node.node_info.id] = NodeState(
          state=NodeState.COMPLETE)
  # Nothing is written when no node is skipped.
  if node_states_dict:
    _save_node_states_dict(execution, node_states_dict)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def initiate_stop(self):
  """Updates pipeline state to signal stopping pipeline execution."""
  stop_flag = self.execution.custom_properties[_STOP_INITIATED]
  data_types_utils.set_metadata_value(stop_flag, 1)
  # Flag the state as modified.
  self._commit = True
def _save_node_states_dict(pipeline_execution: metadata_store_pb2.Execution,
                           node_states: Dict[str, NodeState]) -> None:
  """Stores `node_states` JSON-serialized in the execution's custom properties.

  Args:
    pipeline_execution: Execution proto that is updated in place.
    node_states: Mapping that is serialized with `json_utils.dumps` and saved
      under the `_NODE_STATES` key.
  """
  serialized_states = json_utils.dumps(node_states)
  states_property = pipeline_execution.custom_properties[_NODE_STATES]
  data_types_utils.set_metadata_value(states_property, serialized_states)