Example #1
 def setUp(self):
     super(PlaceholderUtilsTest, self).setUp()
     examples = [standard_artifacts.Examples()]
     examples[0].uri = "/tmp"
     examples[0].split_names = artifact_utils.encode_split_names(
         ["train", "eval"])
     self._serving_spec = infra_validator_pb2.ServingSpec()
     self._serving_spec.tensorflow_serving.tags.extend(
         ["latest", "1.15.0-gpu"])
     self._resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [standard_artifacts.Model()],
                 "examples": examples,
             },
             output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
             exec_properties={
                 "proto_property":
                 json_format.MessageToJson(message=self._serving_spec,
                                           sort_keys=True,
                                           preserving_proto_field_name=True,
                                           indent=0)
             },
             execution_output_uri="test_executor_output_uri",
             stateful_working_dir="test_stateful_working_dir",
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
             class_path="test_class_path"),
     )
     # Resolution context to simulate missing optional values.
     self._none_resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [],
                 "examples": [],
             },
             output_dict={"blessing": []},
             exec_properties={},
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=None,
         platform_config=None)
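
The two contexts above are what the placeholder resolution tests feed into the resolver. Below is a minimal sketch of the typical call; resolve_placeholder_expression matches the entry point exercised by placeholder_utils_test.py, the module path is an assumption based on recent TFX layouts, and expression_pb is a hypothetical, pre-compiled placeholder_pb2.PlaceholderExpression:

from tfx.dsl.compiler import placeholder_utils

# expression_pb is assumed to be an already-compiled PlaceholderExpression,
# e.g. encoding "the uri of the first 'examples' artifact";
# self._resolution_context refers to the context built in setUp above.
resolved_value = placeholder_utils.resolve_placeholder_expression(
    expression_pb, self._resolution_context)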
Example #2
    def testDumpUiMetadata(self):
        trainer = pipeline_pb2.PipelineNode()
        trainer.node_info.type.name = 'tfx.components.trainer.component.Trainer'
        model_run_out_spec = pipeline_pb2.OutputSpec(
            artifact_spec=pipeline_pb2.OutputSpec.ArtifactSpec(
                type=metadata_store_pb2.ArtifactType(
                    name=standard_artifacts.ModelRun.TYPE_NAME)))
        trainer.outputs.outputs['model_run'].CopyFrom(model_run_out_spec)

        model_run = standard_artifacts.ModelRun()
        model_run.uri = 'model_run_uri'
        exec_info = data_types.ExecutionInfo(
            input_dict={},
            output_dict={'model_run': [model_run]},
            exec_properties={},
            execution_id='id')
        ui_metadata_path = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName, 'json')
        fileio.makedirs(os.path.dirname(ui_metadata_path))
        container_entrypoint._dump_ui_metadata(trainer, exec_info,
                                               ui_metadata_path)
        with open(ui_metadata_path) as f:
            ui_metadata = json.load(f)
            self.assertEqual('tensorboard', ui_metadata['outputs'][-1]['type'])
            self.assertEqual('model_run_uri',
                             ui_metadata['outputs'][-1]['source'])
Example #3
File: task_test.py Project: sycdesign/tfx
 def test_node_uid_from_pipeline_node(self):
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = 'pipeline'
   node = pipeline_pb2.PipelineNode()
   node.node_info.id = 'Trainer'
   self.assertEqual(
       task_lib.NodeUid(
           pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
           node_id='Trainer'),
       task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #4
 def test_node_uid_from_pipeline_node(self):
     pipeline = pipeline_pb2.Pipeline()
     pipeline.pipeline_info.id = 'pipeline'
     pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
     node = pipeline_pb2.PipelineNode()
     node.node_info.id = 'Trainer'
     self.assertEqual(
         task_lib.NodeUid(pipeline_uid=task_lib.PipelineUid(
             pipeline_id='pipeline', pipeline_run_id='run0'),
                          node_id='Trainer'),
         task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #5
 def testGetCacheContextTwiceDifferentNodeInfo(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         self._get_cache_context(m)
         self._get_cache_context(m,
                                 custom_pipeline_node=text_format.Parse(
                                     """
           node_info {
             id: "new_node_id"
           }
           """, pipeline_pb2.PipelineNode()))
         # Different node info will result in a new cache context.
         self.assertLen(m.store.get_contexts(), 2)
Example #6
 def testGetCacheContextTwiceDifferentExecutorSpec(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         self._get_cache_context(m)
         self._get_cache_context(m,
                                 custom_pipeline_node=text_format.Parse(
                                     """
           executor {
             python_class_executor_spec {class_path: 'n.e.w'}
           }
           """, pipeline_pb2.PipelineNode()))
         # Different executor spec will result in a new cache context.
         self.assertLen(m.store.get_contexts(), 2)
Example #7
File: task_test.py Project: kp425/tfx
 def test_exec_node_task_create(self):
     pipeline = pipeline_pb2.Pipeline()
     pipeline.pipeline_info.id = 'pipeline'
     pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
     node = pipeline_pb2.PipelineNode()
     node.node_info.id = 'Trainer'
     self.assertEqual(
         task_lib.ExecNodeTask(node_uid=task_lib.NodeUid(
             pipeline_id='pipeline',
             pipeline_run_id='run0',
             node_id='Trainer'),
                               execution_id=123),
         task_lib.ExecNodeTask.create(pipeline, node, 123))
Example #8
 def _set_up_test_execution_info(self,
                                 input_dict=None,
                                 output_dict=None,
                                 exec_properties=None):
   return data_types.ExecutionInfo(
       input_dict=input_dict or {},
       output_dict=output_dict or {},
       exec_properties=exec_properties or {},
       execution_output_uri='/testing/executor/output/',
       stateful_working_dir='/testing/stateful/dir',
       pipeline_node=pipeline_pb2.PipelineNode(
           node_info=pipeline_pb2.NodeInfo(
               type=metadata_store_pb2.ExecutionType(name='Docker_executor'))),
       pipeline_info=pipeline_pb2.PipelineInfo(id='test_pipeline_id'))
Example #9
def _remove_dangling_downstream_nodes(
    node: p_pb2.PipelineNode,
    node_ids_to_keep: Collection[str]) -> p_pb2.PipelineNode:
  """Remove node.downstream_nodes that have been filtered out."""
  # Using a loop instead of set intersection to ensure the same order.
  downstream_nodes_to_keep = [
      downstream_node for downstream_node in node.downstream_nodes
      if downstream_node in node_ids_to_keep
  ]
  if len(downstream_nodes_to_keep) == len(node.downstream_nodes):
    return node
  result = p_pb2.PipelineNode()
  result.CopyFrom(node)
  result.downstream_nodes[:] = downstream_nodes_to_keep
  return result
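
A minimal sketch exercising both return paths of the helper above (assuming the function as defined here and the standard tfx.proto.orchestration import path; the node ids are made up):

from tfx.proto.orchestration import pipeline_pb2 as p_pb2

node = p_pb2.PipelineNode()
node.node_info.id = "Trainer"
node.downstream_nodes.extend(["Evaluator", "Pusher"])

# "Evaluator" is filtered out, so a modified copy is returned.
filtered = _remove_dangling_downstream_nodes(node, node_ids_to_keep={"Pusher"})
assert list(filtered.downstream_nodes) == ["Pusher"]

# Nothing is filtered out, so the original message is returned unchanged.
unchanged = _remove_dangling_downstream_nodes(
    node, node_ids_to_keep={"Evaluator", "Pusher"})
assert unchanged is node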
Example #10
 def _set_up_test_execution_info(self,
                                 input_dict=None,
                                 output_dict=None,
                                 exec_properties=None):
     return data_types.ExecutionInfo(
         execution_id=123,
         input_dict=input_dict or {},
         output_dict=output_dict or {},
         exec_properties=exec_properties or {},
         execution_output_uri='/testing/executor/output/',
         stateful_working_dir='/testing/stateful/dir',
         pipeline_node=pipeline_pb2.PipelineNode(
             node_info=pipeline_pb2.NodeInfo(
                 id='fakecomponent-fakecomponent')),
         pipeline_info=pipeline_pb2.PipelineInfo(id='Test'),
         pipeline_run_id='123')
Example #11
 def _get_execution_info(self, input_dict, output_dict, exec_properties):
     pipeline_node = pipeline_pb2.PipelineNode(
         node_info={'id': 'MyPythonNode'})
     pipeline_info = pipeline_pb2.PipelineInfo(id='MyPipeline')
     stateful_working_dir = os.path.join(self.tmp_dir,
                                         'stateful_working_dir')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     return data_types.ExecutionInfo(
         execution_id=1,
         input_dict=input_dict,
         output_dict=output_dict,
         exec_properties=exec_properties,
         stateful_working_dir=stateful_working_dir,
         execution_output_uri=executor_output_uri,
         pipeline_node=pipeline_node,
         pipeline_info=pipeline_info,
         pipeline_run_id=99)
Example #12
 def setUp(self):
   super().setUp()
   self._connection_config = metadata_store_pb2.ConnectionConfig()
   self._connection_config.sqlite.SetInParent()
   self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
   self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
   self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
   self._parameters = {'module_file': self._module_file_path}
   self._module_file_content = 'module content'
   self._pipeline_node = text_format.Parse(
       """
       executor {
         python_class_executor_spec {class_path: 'a.b.c'}
       }
       """, pipeline_pb2.PipelineNode())
   self._executor_class_path = 'a.b.c'
   self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
Example #13
  def testExecutionInfoSerialization(self):
    my_artifact = _MyArtifact()
    my_artifact.int1 = 111

    execution_output_uri = 'output/uri'
    stateful_working_dir = 'working/dir'
    exec_properties = {
        'property1': 'value1',
        'property2': 'value2',
    }
    pipeline_info = pipeline_pb2.PipelineInfo(id='my_pipeline')
    pipeline_node = text_format.Parse(
        """
        node_info {
          id: 'my_node'
        }
        """, pipeline_pb2.PipelineNode())

    original = data_types.ExecutionInfo(
        input_dict={'input': [my_artifact]},
        output_dict={'output': [my_artifact]},
        exec_properties=exec_properties,
        execution_output_uri=execution_output_uri,
        stateful_working_dir=stateful_working_dir,
        pipeline_info=pipeline_info,
        pipeline_node=pipeline_node)

    serialized = python_execution_binary_utils.serialize_execution_info(
        original)
    rehydrated = python_execution_binary_utils.deserialize_execution_info(
        serialized)

    self.CheckArtifactDict(rehydrated.input_dict, {'input': [my_artifact]})
    self.CheckArtifactDict(rehydrated.output_dict, {'output': [my_artifact]})
    self.assertEqual(rehydrated.exec_properties, exec_properties)
    self.assertEqual(rehydrated.execution_output_uri, execution_output_uri)
    self.assertEqual(rehydrated.stateful_working_dir, stateful_working_dir)
    self.assertProtoEquals(rehydrated.pipeline_info, original.pipeline_info)
    self.assertProtoEquals(rehydrated.pipeline_node, original.pipeline_node)
Example #14
def _handle_missing_inputs(
    node: p_pb2.PipelineNode,
    node_ids_to_keep: Collection[str],
    pipeline_run_id_fn: Callable[[p_pb2.InputSpec.Channel], str],
) -> p_pb2.PipelineNode:
  """Private helper function to handle missing inputs.

  Args:
    node: The Pipeline node to check for missing inputs.
    node_ids_to_keep: The node_ids that are not filtered out.
    pipeline_run_id_fn: If this node has upstream nodes that are filtered out,
      this function is used to obtain the pipeline_run_id for that input
      channel, which is then provided as the 'pipeline_run_id' in the
      'pipeline_run' ContextQuery.

  Returns:
    A copy of the Pipeline node where all inputs that reference filtered-out
    nodes have their 'pipeline_run' ContextQuery updated.
  """
  upstream_nodes_to_replace = set()
  upstream_nodes_to_keep = []
  for upstream_node in node.upstream_nodes:
    if upstream_node in node_ids_to_keep:
      upstream_nodes_to_keep.append(upstream_node)
    else:
      upstream_nodes_to_replace.add(upstream_node)

  if not upstream_nodes_to_replace:
    return node  # No parent missing, no need to change anything.

  result = p_pb2.PipelineNode()
  result.CopyFrom(node)
  for input_spec in result.inputs.inputs.values():
    for channel in input_spec.channels:
      if channel.producer_node_query.id in upstream_nodes_to_replace:
        pipeline_run_id = pipeline_run_id_fn(channel)
        _replace_pipeline_run_id_in_channel(channel, pipeline_run_id)
  result.upstream_nodes[:] = upstream_nodes_to_keep
  return result
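
A minimal sketch of how this helper might be driven; it assumes the function above together with its _replace_pipeline_run_id_in_channel dependency from the same module, and the node contents and run id are made up:

from google.protobuf import text_format
from tfx.proto.orchestration import pipeline_pb2 as p_pb2

node = text_format.Parse(
    """
    node_info { id: "Trainer" }
    upstream_nodes: "ExampleGen"
    inputs {
      inputs {
        key: "examples"
        value {
          channels {
            producer_node_query { id: "ExampleGen" }
          }
        }
      }
    }
    """, p_pb2.PipelineNode())

# Hypothetical lookup: reroute every missing input to one fixed previous run.
def _previous_run_id(channel: p_pb2.InputSpec.Channel) -> str:
  del channel  # A real implementation might inspect the channel.
  return "run_0"

# "ExampleGen" is filtered out, so its channel gets the pipeline_run query
# and the upstream edge disappears from the returned copy.
rerouted = _handle_missing_inputs(
    node, node_ids_to_keep=set(), pipeline_run_id_fn=_previous_run_id)
assert not rerouted.upstream_nodes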
Example #15
 def testRunExecutorWithBeamPipelineArgs(self):
     executor_spec = text_format.Parse(
         """
   python_executor_spec: {
       class_path: "tfx.orchestration.portable.beam_executor_operator_test.ValidateBeamPipelineArgsExecutor"
   }
   beam_pipeline_args: "--runner=DirectRunner"
 """, executable_spec_pb2.BeamExecutableSpec())
     operator = beam_executor_operator.BeamExecutorOperator(executor_spec)
     pipeline_node = pipeline_pb2.PipelineNode(
         node_info={'id': 'MyBeamNode'})
     pipeline_info = pipeline_pb2.PipelineInfo(id='MyPipeline')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     executor_output = operator.run_executor(
         data_types.ExecutionInfo(
             execution_id=1,
             input_dict={'input_key': [standard_artifacts.Examples()]},
             output_dict={'output_key': [standard_artifacts.Model()]},
             exec_properties={},
             execution_output_uri=executor_output_uri,
             pipeline_node=pipeline_node,
             pipeline_info=pipeline_info,
             pipeline_run_id=99))
     self.assertProtoPartiallyEquals(
         """
       output_artifacts {
         key: "output_key"
         value {
           artifacts {
             custom_properties {
               key: "name"
               value {
                 string_value: "MyPipeline.MyBeamNode.my_model"
               }
             }
           }
         }
       }""", executor_output)
Example #16
File: driver_test.py Project: yifanmai/tfx
  def testRun(self):
    # Create input dir.
    self._input_base_path = os.path.join(self._test_dir, 'input_base')
    tf.io.gfile.makedirs(self._input_base_path)

    # Create PipelineInfo and PipelineNode
    pipeline_info = pipeline_pb2.PipelineInfo()
    pipeline_node = pipeline_pb2.PipelineNode()

    # Fake previous outputs
    span1_v1_split1 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split1', 'data')
    io_utils.write_string_file(span1_v1_split1, 'testing11')
    span1_v1_split2 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split2', 'data')
    io_utils.write_string_file(span1_v1_split2, 'testing12')

    ir_driver = driver.Driver(self._mock_metadata, pipeline_info, pipeline_node)
    example = standard_artifacts.Examples()

    # Prepare output_dict.
    example.uri = 'my_uri'  # Will verify that this uri is not changed.
    output_dict = {utils.EXAMPLES_KEY: [example]}

    # Prepare exec_properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_base_path,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN}/version{VERSION}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN}/version{VERSION}/split2/*')
                ]),
                preserving_proto_field_name=True),
    }
    result = ir_driver.run(None, output_dict, exec_properties)
    print(result)
    # Assert exec_properties' values
    exec_properties = result.exec_properties
    self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
    self.assertEqual(exec_properties[utils.VERSION_PROPERTY_NAME].int_value, 1)
    updated_input_config = example_gen_pb2.Input()
    json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                      updated_input_config)
    self.assertProtoEquals(
        """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)
    self.assertRegex(
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
    # Assert output_artifacts' values
    self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
    output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
    self.assertEqual(output_example.uri, example.uri)
    self.assertEqual(
        output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
        '1')
    self.assertEqual(
        output_example.custom_properties[
            utils.VERSION_PROPERTY_NAME].string_value, '1')
    self.assertRegex(
        output_example.custom_properties[
            utils.FINGERPRINT_PROPERTY_NAME].string_value,
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
Example #17
    def testSuccess(self):
        with self._mlmd_connection as m:
            # Publish two models that will be consumed by the downstream resolver.
            output_model_1 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_1.uri = 'my_model_uri_1'

            output_model_2 = types.Artifact(
                self._my_trainer.outputs.outputs['model'].artifact_spec.type)
            output_model_2.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m,
                                                    self._my_trainer.contexts)
            execution = execution_publish_utils.register_execution(
                m, self._my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, execution.id, contexts, {
                    'model': [output_model_1, output_model_2],
                })

        handler = resolver_node_handler.ResolverNodeHandler()
        execution_metadata = handler.run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # There is no way to directly verify the resolver's output artifact,
            # so a fake downstream component that listens to the resolver's
            # output is created, and we verify its input instead.
            down_stream_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())
            downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=down_stream_node.inputs)
            downstream_input_model = downstream_input_artifacts['input_models']
            self.assertLen(downstream_input_model, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                downstream_input_model[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            [execution] = m.store.get_executions_by_id([execution_metadata.id])

            self.assertProtoPartiallyEquals("""
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                                            execution,
                                            ignored_fields=[
                                                'create_time_since_epoch',
                                                'last_update_time_since_epoch'
                                            ])
Example #18
File: compiler.py Project: jay90099/tfx
  def _compile_node(
      self,
      tfx_node: base_node.BaseNode,
      compile_context: _CompilerContext,
      deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
      enable_cache: bool,
  ) -> pipeline_pb2.PipelineNode:
    """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers and platform specific configs.
      enable_cache: Whether caching is enabled.

    Raises:
      TypeError: When supplied tfx_node has values of invalid type.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
    node = pipeline_pb2.PipelineNode()

    # Step 1: Node info
    node.node_info.type.name = tfx_node.type
    if isinstance(tfx_node,
                  base_component.BaseComponent) and tfx_node.type_annotation:
      node.node_info.type.base_type = (
          tfx_node.type_annotation.MLMD_SYSTEM_BASE_TYPE)
    node.node_info.id = tfx_node.id

    # Step 2: Node Context
    # Context for the pipeline, across pipeline runs.
    pipeline_context_pb = node.contexts.contexts.add()
    pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

    # Context for the current pipeline run.
    if compile_context.is_sync_mode:
      pipeline_run_context_pb = node.contexts.contexts.add()
      pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

    # Context for the node, across pipeline runs.
    node_context_pb = node.contexts.contexts.add()
    node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
    node_context_pb.name.field_value.string_value = (
        compiler_utils.node_context_name(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id))

    # Pre Step 3: Alter graph topology if needed.
    if compile_context.is_async_mode:
      tfx_node_inputs = self._embed_upstream_resolver_nodes(
          compile_context, tfx_node, node)
    else:
      tfx_node_inputs = tfx_node.inputs

    # Step 3: Node inputs

    # Step 3.1: Generate implicit input channels
    # Step 3.1.1: Conditionals
    implicit_input_channels = {}
    predicates = conditional.get_predicates(tfx_node)
    if predicates:
      implicit_keys_map = {}
      for key, chnl in tfx_node_inputs.items():
        if not isinstance(chnl, types.Channel):
          raise ValueError(
              "Conditionals only support using a channel as a predicate.")
        implicit_keys_map[compiler_utils.implicit_channel_key(chnl)] = key
      encoded_predicates = []
      for predicate in predicates:
        for chnl in predicate.dependent_channels():
          implicit_key = compiler_utils.implicit_channel_key(chnl)
          if implicit_key not in implicit_keys_map:
            # Store this channel and add it to the node inputs later.
            implicit_input_channels[implicit_key] = chnl
        encoded_predicates.append(
            predicate.encode_with_keys(
                compiler_utils.build_channel_to_key_fn(implicit_keys_map)))
      # In an async pipeline, the conditional resolver step should be the last
      # step among all resolver steps of a node.
      resolver_step = node.inputs.resolver_config.resolver_steps.add()
      resolver_step.class_path = constants.CONDITIONAL_RESOLVER_CLASS_PATH
      resolver_step.config_json = json_utils.dumps(
          {"predicates": encoded_predicates})

    # Step 3.1.2: Add placeholder exec props to implicit_input_channels
    for key, value in tfx_node.exec_properties.items():
      if isinstance(value, placeholder.ChannelWrappedPlaceholder):
        if not (inspect.isclass(value.channel.type) and
                issubclass(value.channel.type, value_artifact.ValueArtifact)):
          raise ValueError("The output channel of a dynamic exec property "
                           "must be a ValueArtifact.")
        implicit_key = compiler_utils.implicit_channel_key(value.channel)
        implicit_input_channels[implicit_key] = value.channel

    # Step 3.2: Handle ForEach.
    dsl_contexts = context_manager.get_contexts(tfx_node)
    for dsl_context in dsl_contexts:
      if isinstance(dsl_context, for_each.ForEachContext):
        for input_key, channel in tfx_node_inputs.items():
          if (isinstance(channel, types.LoopVarChannel) and
              channel.wrapped is dsl_context.wrapped_channel):
            node.inputs.resolver_config.resolver_steps.extend(
                _compile_for_each_context(input_key))
            break
        else:
          # Ideally should not reach here as the same check is performed at
          # ForEachContext.will_add_node().
          raise ValueError(
              f"Unable to locate ForEach loop variable {dsl_context.channel} "
              f"from inputs of node {tfx_node.id}.")

    # Check loop variable is used outside the ForEach.
    for input_key, channel in tfx_node_inputs.items():
      if isinstance(channel, types.LoopVarChannel):
        dsl_context_ids = {c.id for c in dsl_contexts}
        if channel.context_id not in dsl_context_ids:
          raise ValueError(
              "Loop variable cannot be used outside the ForEach "
              f"(node_id = {tfx_node.id}, input_key = {input_key}).")

    # Step 3.3: Fill node inputs
    for key, value in itertools.chain(tfx_node_inputs.items(),
                                      implicit_input_channels.items()):
      input_spec = node.inputs.inputs[key]
      for input_channel in channel_utils.get_individual_channels(value):
        chnl = input_spec.channels.add()

        # If the node input comes from another node's output, fill the context
        # queries with the producer node's contexts.
        if input_channel in compile_context.node_outputs:
          chnl.producer_node_query.id = input_channel.producer_component_id

          # Here we rely on pipeline.components to be topologically sorted.
          assert input_channel.producer_component_id in compile_context.node_pbs, (
              "producer component should have already been compiled.")
          producer_pb = compile_context.node_pbs[
              input_channel.producer_component_id]
          for producer_context in producer_pb.contexts.contexts:
            context_query = chnl.context_queries.add()
            context_query.type.CopyFrom(producer_context.type)
            context_query.name.CopyFrom(producer_context.name)

        # If the node input does not come from another node's output, fill the
        # context queries based on Channel info. We require every channel to
        # have a pipeline context and fill it automatically.
        else:
          # Add pipeline context query.
          context_query = chnl.context_queries.add()
          context_query.type.CopyFrom(pipeline_context_pb.type)
          context_query.name.CopyFrom(pipeline_context_pb.name)

          # Optionally add node context query.
          if input_channel.producer_component_id:
            # Add node context query if `producer_component_id` is present.
            chnl.producer_node_query.id = input_channel.producer_component_id
            node_context_query = chnl.context_queries.add()
            node_context_query.type.name = constants.NODE_CONTEXT_TYPE_NAME
            node_context_query.name.field_value.string_value = "{}.{}".format(
                compile_context.pipeline_info.pipeline_context_name,
                input_channel.producer_component_id)

        artifact_type = input_channel.type._get_artifact_type()  # pylint: disable=protected-access
        chnl.artifact_query.type.CopyFrom(artifact_type)
        chnl.artifact_query.type.ClearField("properties")

        if input_channel.output_key:
          chnl.output_key = input_channel.output_key

        # Set NodeInputs.min_count.
        if isinstance(tfx_node, base_component.BaseComponent):
          if key in implicit_input_channels:
            # Mark all input channels as optional for implicit inputs
            # (e.g. conditionals). This is suboptimal, but still a safe guess to
            # avoid breaking the pipeline run.
            input_spec.min_count = 0
          else:
            try:
              # Calculating min_count from ComponentSpec.INPUTS.
              if tfx_node.spec.is_optional_input(key):
                input_spec.min_count = 0
              else:
                input_spec.min_count = 1
            except KeyError:
              # Currently we can reach here if the upstream resolver node's
              # inputs are embedded into the current node (in async mode). We
              # always regard a resolver's inputs as optional.
              if compile_context.is_async_mode:
                input_spec.min_count = 0
              else:
                raise

    # TODO(b/170694459): Refactor special nodes as plugins.
    # Step 3.4: Special treatment for Resolver node.
    if compiler_utils.is_resolver(tfx_node):
      assert compile_context.is_sync_mode
      node.inputs.resolver_config.resolver_steps.extend(
          _compile_resolver_node(tfx_node))

    # Step 4: Node outputs
    for key, value in tfx_node.outputs.items():
      # Register the output in the context.
      compile_context.node_outputs.add(value)
    if (isinstance(tfx_node, base_component.BaseComponent) or
        compiler_utils.is_importer(tfx_node)):
      self._compile_node_outputs(tfx_node, node)

    # Step 5: Node parameters
    if not compiler_utils.is_resolver(tfx_node):
      for key, value in tfx_node.exec_properties.items():
        if value is None:
          continue
        parameter_value = node.parameters.parameters[key]

        # Order matters, because runtime parameter can be in serialized string.
        if isinstance(value, data_types.RuntimeParameter):
          compiler_utils.set_runtime_parameter_pb(
              parameter_value.runtime_parameter, value.name, value.ptype,
              value.default)
        # RuntimeInfoPlaceholder passes Execution parameters of Facade
        # components.
        elif isinstance(value, placeholder.RuntimeInfoPlaceholder):
          parameter_value.placeholder.CopyFrom(value.encode())
        # ChannelWrappedPlaceholder passes dynamic execution parameter.
        elif isinstance(value, placeholder.ChannelWrappedPlaceholder):
          compiler_utils.validate_dynamic_exec_ph_operator(value)
          parameter_value.placeholder.CopyFrom(
              value.encode_with_keys(compiler_utils.implicit_channel_key))
        else:
          try:
            data_types_utils.set_parameter_value(parameter_value, value)
          except ValueError:
            raise ValueError(
                "Component {} got unsupported parameter {} with type {}."
                .format(tfx_node.id, key, type(value))) from ValueError

    # Step 6: Executor spec and optional driver spec for components
    if isinstance(tfx_node, base_component.BaseComponent):
      executor_spec = tfx_node.executor_spec.encode(
          component_spec=tfx_node.spec)
      deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

      # TODO(b/163433174): Remove specialized logic once generalization of
      # driver spec is done.
      if tfx_node.driver_class != base_driver.BaseDriver:
        driver_class_path = _fully_qualified_name(tfx_node.driver_class)
        driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
        driver_spec.class_path = driver_class_path
        deployment_config.custom_driver_specs[tfx_node.id].Pack(driver_spec)

    # Step 7: Upstream/Downstream nodes
    node.upstream_nodes.extend(
        self._find_runtime_upstream_node_ids(compile_context, tfx_node))
    node.downstream_nodes.extend(
        self._find_runtime_downstream_node_ids(compile_context, tfx_node))

    # Step 8: Node execution options
    node.execution_options.caching_options.enable_cache = enable_cache

    # Step 9: Per-node platform config
    if isinstance(tfx_node, base_component.BaseComponent):
      tfx_component = cast(base_component.BaseComponent, tfx_node)
      if tfx_component.platform_config:
        deployment_config.node_level_platform_configs[tfx_node.id].Pack(
            tfx_component.platform_config)

    return node
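
For orientation, here is a sketch of the overall shape of the PipelineNode proto this method produces; only the fields set in Steps 1, 2, 7 and 8 are shown, and every value is made up:

from google.protobuf import text_format
from tfx.proto.orchestration import pipeline_pb2

compiled_node = text_format.Parse(
    """
    node_info {
      type { name: "tfx.components.trainer.component.Trainer" }
      id: "Trainer"
    }
    contexts {
      contexts {
        type { name: "pipeline" }
        name { field_value { string_value: "my_pipeline" } }
      }
    }
    upstream_nodes: "ExampleGen"
    downstream_nodes: "Pusher"
    execution_options {
      caching_options { enable_cache: true }
    }
    """, pipeline_pb2.PipelineNode())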
Example #19
File: compiler.py Project: ragnariock/tfx
    def _compile_node(
        self, tfx_node: base_node.BaseNode, compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers and platform specific configs.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if (compile_context.execution_mode ==
                pipeline_pb2.Pipeline.ExecutionMode.SYNC):
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Step 3: Node inputs
        for key, value in tfx_node.inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on whether inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node
        if compiler_utils.is_resolver(tfx_node):
            resolver = tfx_node.exec_properties[resolver_node.RESOLVER_CLASS]
            if resolver == latest_artifacts_resolver.LatestArtifactsResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.LATEST_ARTIFACT)
            elif resolver == latest_blessed_model_resolver.LatestBlessedModelResolver:
                node.inputs.resolver_config.resolver_policy = (
                    pipeline_pb2.ResolverConfig.ResolverPolicy.
                    LATEST_BLESSED_MODEL)
            else:
                raise ValueError("Got unsupported resolver policy: {}".format(
                    resolver.type))

        # Step 4: Node outputs
        if compiler_utils.is_component(tfx_node):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because runtime parameter can be in serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                elif isinstance(value, str):
                    parameter_value.field_value.string_value = value
                elif isinstance(value, int):
                    parameter_value.field_value.int_value = value
                elif isinstance(value, float):
                    parameter_value.field_value.double_value = value
                else:
                    raise ValueError(
                        "Component {} got unsupported parameter {} with type {}."
                        .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if compiler_utils.is_component(tfx_node):
            executor_spec = tfx_node.executor_spec.encode()
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from
        # run to run. We sort them so that the compiler generates consistent results.
        node.upstream_nodes.extend(
            sorted([
                upstream_component.id
                for upstream_component in tfx_node.upstream_nodes
            ]))
        node.downstream_nodes.extend(
            sorted([
                downstream_component.id
                for downstream_component in tfx_node.downstream_nodes
            ]))

        # Step 8: Node execution options
        # TODO(kennethyang): Add support for node execution options.

        return node
Example #20
File: compiler.py Project: jasonz1112/tfx
    def _compile_node(
        self,
        tfx_node: base_node.BaseNode,
        compile_context: _CompilerContext,
        deployment_config: pipeline_pb2.IntermediateDeploymentConfig,
        enable_cache: bool,
    ) -> pipeline_pb2.PipelineNode:
        """Compiles an individual TFX node into a PipelineNode proto.

    Args:
      tfx_node: A TFX node.
      compile_context: Resources needed to compile the node.
      deployment_config: Intermediate deployment config to set. Will include
        related specs for executors, drivers and platform specific configs.
      enable_cache: Whether caching is enabled.

    Raises:
      TypeError: When supplied tfx_node has values of invalid type.

    Returns:
      A PipelineNode proto that encodes information of the node.
    """
        node = pipeline_pb2.PipelineNode()

        # Step 1: Node info
        node.node_info.type.name = tfx_node.type
        node.node_info.id = tfx_node.id

        # Step 2: Node Context
        # Context for the pipeline, across pipeline runs.
        pipeline_context_pb = node.contexts.contexts.add()
        pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
        pipeline_context_pb.name.field_value.string_value = compile_context.pipeline_info.pipeline_context_name

        # Context for the current pipeline run.
        if compile_context.is_sync_mode:
            pipeline_run_context_pb = node.contexts.contexts.add()
            pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
            compiler_utils.set_runtime_parameter_pb(
                pipeline_run_context_pb.name.runtime_parameter,
                constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        # Context for the node, across pipeline runs.
        node_context_pb = node.contexts.contexts.add()
        node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
        node_context_pb.name.field_value.string_value = "{}.{}".format(
            compile_context.pipeline_info.pipeline_context_name,
            node.node_info.id)

        # Pre Step 3: Alter graph topology if needed.
        if compile_context.is_async_mode:
            tfx_node_inputs = self._compile_resolver_config(
                compile_context, tfx_node, node)
        else:
            tfx_node_inputs = tfx_node.inputs

        # Step 3: Node inputs
        for key, value in tfx_node_inputs.items():
            input_spec = node.inputs.inputs[key]
            channel = input_spec.channels.add()
            if value.producer_component_id:
                channel.producer_node_query.id = value.producer_component_id

                # Here we rely on pipeline.components to be topologically sorted.
                assert value.producer_component_id in compile_context.node_pbs, (
                    "producer component should have already been compiled.")
                producer_pb = compile_context.node_pbs[
                    value.producer_component_id]
                for producer_context in producer_pb.contexts.contexts:
                    if (not compiler_utils.is_resolver(tfx_node)
                            or producer_context.name.runtime_parameter.name !=
                            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME):
                        context_query = channel.context_queries.add()
                        context_query.type.CopyFrom(producer_context.type)
                        context_query.name.CopyFrom(producer_context.name)
            else:
                # Caveat: the portable core requires every channel to have at
                # least one context, but for cases like system nodes and
                # producer-consumer pipelines, a channel may not have contexts
                # at all. In these cases, we use the pipeline-level context as
                # the input channel context.
                context_query = channel.context_queries.add()
                context_query.type.CopyFrom(pipeline_context_pb.type)
                context_query.name.CopyFrom(pipeline_context_pb.name)

            artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
            channel.artifact_query.type.CopyFrom(artifact_type)
            channel.artifact_query.type.ClearField("properties")

            if value.output_key:
                channel.output_key = value.output_key

            # TODO(b/158712886): Calculate min_count based on whether inputs are optional.
            # min_count = 0 stands for optional input and 1 stands for required input.

        # Step 3.1: Special treatment for Resolver node.
        if compiler_utils.is_resolver(tfx_node):
            assert compile_context.is_sync_mode
            node.inputs.resolver_config.resolver_steps.extend(
                _convert_to_resolver_steps(tfx_node))

        # Step 4: Node outputs
        if isinstance(tfx_node, base_component.BaseComponent):
            for key, value in tfx_node.outputs.items():
                output_spec = node.outputs.outputs[key]
                artifact_type = value.type._get_artifact_type()  # pylint: disable=protected-access
                output_spec.artifact_spec.type.CopyFrom(artifact_type)
                for prop_key, prop_value in value.additional_properties.items(
                ):
                    _check_property_value_type(prop_key, prop_value,
                                               output_spec.artifact_spec.type)
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_properties[prop_key].field_value,
                        prop_value)
                for prop_key, prop_value in value.additional_custom_properties.items(
                ):
                    data_types_utils.set_metadata_value(
                        output_spec.artifact_spec.
                        additional_custom_properties[prop_key].field_value,
                        prop_value)

        # TODO(b/170694459): Refactor special nodes as plugins.
        # Step 4.1: Special treatment for Importer node
        if compiler_utils.is_importer(tfx_node):
            self._compile_importer_node_outputs(tfx_node, node)

        # Step 5: Node parameters
        if not compiler_utils.is_resolver(tfx_node):
            for key, value in tfx_node.exec_properties.items():
                if value is None:
                    continue
                # Ignore the following two properties for an importer node, because
                # they are already attached to the artifacts produced by the importer node.
                if compiler_utils.is_importer(tfx_node) and (
                        key == importer.PROPERTIES_KEY
                        or key == importer.CUSTOM_PROPERTIES_KEY):
                    continue
                parameter_value = node.parameters.parameters[key]

                # Order matters, because runtime parameter can be in serialized string.
                if isinstance(value, data_types.RuntimeParameter):
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, value.name,
                        value.ptype, value.default)
                elif isinstance(value, str) and re.search(
                        data_types.RUNTIME_PARAMETER_PATTERN, value):
                    runtime_param = json.loads(value)
                    compiler_utils.set_runtime_parameter_pb(
                        parameter_value.runtime_parameter, runtime_param.name,
                        runtime_param.ptype, runtime_param.default)
                else:
                    try:
                        data_types_utils.set_metadata_value(
                            parameter_value.field_value, value)
                    except ValueError:
                        raise ValueError(
                            "Component {} got unsupported parameter {} with type {}."
                            .format(tfx_node.id, key, type(value)))

        # Step 6: Executor spec and optional driver spec for components
        if isinstance(tfx_node, base_component.BaseComponent):
            executor_spec = tfx_node.executor_spec.encode(
                component_spec=tfx_node.spec)
            deployment_config.executor_specs[tfx_node.id].Pack(executor_spec)

            # TODO(b/163433174): Remove specialized logic once generalization of
            # driver spec is done.
            if tfx_node.driver_class != base_driver.BaseDriver:
                driver_class_path = "{}.{}".format(
                    tfx_node.driver_class.__module__,
                    tfx_node.driver_class.__name__)
                driver_spec = executable_spec_pb2.PythonClassExecutableSpec()
                driver_spec.class_path = driver_class_path
                deployment_config.custom_driver_specs[tfx_node.id].Pack(
                    driver_spec)

        # Step 7: Upstream/Downstream nodes
        # Note: the order of tfx_node.upstream_nodes is inconsistent from run
        # to run. We sort them so that the compiler generates consistent
        # results. For ASYNC mode, upstream/downstream node information is not
        # set, as the compiled IR graph topology can differ from the topology
        # at pipeline authoring time; for example, ResolverNode is removed.
        if compile_context.is_sync_mode:
            node.upstream_nodes.extend(
                sorted(node.id for node in tfx_node.upstream_nodes))
            node.downstream_nodes.extend(
                sorted(node.id for node in tfx_node.downstream_nodes))

        # Step 8: Node execution options
        node.execution_options.caching_options.enable_cache = enable_cache

        # Step 9: Per-node platform config
        if isinstance(tfx_node, base_component.BaseComponent):
            tfx_component = cast(base_component.BaseComponent, tfx_node)
            if tfx_component.platform_config:
                deployment_config.node_level_platform_configs[
                    tfx_node.id].Pack(tfx_component.platform_config)

        return node
Example #21
        }
      }
    }
   outputs {
      key: "output_3"
      value {
        artifact_spec {
          type {
            id: 3
            name: "String"
          }
        }
      }
    }
  }
""", pipeline_pb2.PipelineNode())


class OutputUtilsTest(test_case_utils.TfxTest, parameterized.TestCase):
    def setUp(self):
        super().setUp()
        pipeline_runtime_spec = pipeline_pb2.PipelineRuntimeSpec()
        pipeline_runtime_spec.pipeline_root.field_value.string_value = self.tmp_dir
        pipeline_runtime_spec.pipeline_run_id.field_value.string_value = (
            'test_run_0')
        self._pipeline_runtime_spec = pipeline_runtime_spec

    def _output_resolver(self, execution_mode=pipeline_pb2.Pipeline.SYNC):
        return outputs_utils.OutputsResolver(
            pipeline_node=_PIPELINE_NODE,
            pipeline_info=_PIPELINE_INFO,