예제 #1
0
 def setUp(self):
     """Builds the placeholder-resolution fixtures shared by the tests.

     Creates:
       * self._resolution_context: a fully populated ResolutionContext —
         model/examples inputs, a blessing output, a serialized ServingSpec
         exec property, pipeline node/info, and a Python-class executor spec.
       * self._none_resolution_context: a ResolutionContext whose input,
         output and exec-property dicts are empty, simulating missing
         optional values.
     """
     super(PlaceholderUtilsTest, self).setUp()
     # A single Examples artifact at /tmp carrying "train" and "eval" splits.
     examples = [standard_artifacts.Examples()]
     examples[0].uri = "/tmp"
     examples[0].split_names = artifact_utils.encode_split_names(
         ["train", "eval"])
     self._serving_spec = infra_validator_pb2.ServingSpec()
     self._serving_spec.tensorflow_serving.tags.extend(
         ["latest", "1.15.0-gpu"])
     self._resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [standard_artifacts.Model()],
                 "examples": examples,
             },
             output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
             exec_properties={
                 # The ServingSpec above serialized to JSON deterministically
                 # (sorted keys, original proto field names, indent=0).
                 "proto_property":
                 json_format.MessageToJson(message=self._serving_spec,
                                           sort_keys=True,
                                           preserving_proto_field_name=True,
                                           indent=0)
             },
             execution_output_uri="test_executor_output_uri",
             stateful_working_dir="test_stateful_working_dir",
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
             class_path="test_class_path"),
     )
     # Resolution context to simulate missing optional values.
     self._none_resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [],
                 "examples": [],
             },
             output_dict={"blessing": []},
             exec_properties={},
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=None,
         platform_config=None)
예제 #2
0
 def testGetCacheContextTwiceDifferentPipelineInfo(self):
     """Requesting a cache context under a new PipelineInfo creates a new one."""
     other_pipeline_info = pipeline_pb2.PipelineInfo(id='new_id')
     with metadata.Metadata(connection_config=self._connection_config) as handle:
         self._get_cache_context(handle)
         self._get_cache_context(handle, custom_pipeline_info=other_pipeline_info)
         # Different pipeline info will result in new cache context.
         self.assertLen(handle.store.get_contexts(), 2)
 def _set_up_test_execution_info(self,
                                 input_dict=None,
                                 output_dict=None,
                                 exec_properties=None):
   """Builds an ExecutionInfo fixture for a 'Docker_executor' typed node.

   Any dict argument left as None (or empty/falsy) is replaced by a fresh
   empty dict.
   """
   # Node whose execution type names the Docker executor under test.
   docker_node = pipeline_pb2.PipelineNode(
       node_info=pipeline_pb2.NodeInfo(
           type=metadata_store_pb2.ExecutionType(name='Docker_executor')))
   test_pipeline = pipeline_pb2.PipelineInfo(id='test_pipeline_id')
   return data_types.ExecutionInfo(
       input_dict=input_dict or {},
       output_dict=output_dict or {},
       exec_properties=exec_properties or {},
       execution_output_uri='/testing/executor/output/',
       stateful_working_dir='/testing/stateful/dir',
       pipeline_node=docker_node,
       pipeline_info=test_pipeline)
예제 #4
0
 def _set_up_test_execution_info(self,
                                 input_dict=None,
                                 output_dict=None,
                                 exec_properties=None):
     """Builds an ExecutionInfo fixture for the fake component node.

     Any dict argument left as None (or empty/falsy) is replaced by a fresh
     empty dict.
     """
     fake_node = pipeline_pb2.PipelineNode(
         node_info=pipeline_pb2.NodeInfo(id='fakecomponent-fakecomponent'))
     return data_types.ExecutionInfo(
         execution_id=123,
         input_dict=input_dict or {},
         output_dict=output_dict or {},
         exec_properties=exec_properties or {},
         execution_output_uri='/testing/executor/output/',
         stateful_working_dir='/testing/stateful/dir',
         pipeline_node=fake_node,
         pipeline_info=pipeline_pb2.PipelineInfo(id='Test'),
         pipeline_run_id='123')
예제 #5
0
 def _get_execution_info(self, input_dict, output_dict, exec_properties):
     """Assembles an ExecutionInfo whose dirs live under self.tmp_dir."""
     node = pipeline_pb2.PipelineNode(node_info={'id': 'MyPythonNode'})
     info = pipeline_pb2.PipelineInfo(id='MyPipeline')
     # Working dir and executor-output uri are both rooted in tmp_dir so the
     # test leaves no artifacts behind.
     working_dir = os.path.join(self.tmp_dir, 'stateful_working_dir')
     output_uri = os.path.join(self.tmp_dir, 'executor_output')
     return data_types.ExecutionInfo(
         execution_id=1,
         input_dict=input_dict,
         output_dict=output_dict,
         exec_properties=exec_properties,
         stateful_working_dir=working_dir,
         execution_output_uri=output_uri,
         pipeline_node=node,
         pipeline_info=info,
         pipeline_run_id=99)
예제 #6
0
 def setUp(self):
   """Prepares shared fixtures: MLMD config, artifact dicts, and a node IR.

   Sets an in-memory sqlite MLMD connection config, input/output artifact
   dicts, exec parameters pointing at a module file under tmp_dir, and a
   PipelineNode whose executor spec names class path 'a.b.c'.
   """
   super().setUp()
   # SetInParent selects the sqlite oneof with all-default fields, i.e. an
   # in-memory database.
   self._connection_config = metadata_store_pb2.ConnectionConfig()
   self._connection_config.sqlite.SetInParent()
   self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
   self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
   self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
   self._parameters = {'module_file': self._module_file_path}
   self._module_file_content = 'module content'
   # Node IR whose executor spec should resolve to _executor_class_path.
   self._pipeline_node = text_format.Parse(
       """
       executor {
         python_class_executor_spec {class_path: 'a.b.c'}
       }
       """, pipeline_pb2.PipelineNode())
   self._executor_class_path = 'a.b.c'
   self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
  def testExecutionInfoSerialization(self):
    """Round-trips an ExecutionInfo through serialize/deserialize."""
    artifact = _MyArtifact()
    artifact.int1 = 111

    properties = {
        'property1': 'value1',
        'property2': 'value2',
    }
    node = text_format.Parse(
        """
        node_info {
          id: 'my_node'
        }
        """, pipeline_pb2.PipelineNode())
    info = pipeline_pb2.PipelineInfo(id='my_pipeline')

    original = data_types.ExecutionInfo(
        input_dict={'input': [artifact]},
        output_dict={'output': [artifact]},
        exec_properties=properties,
        execution_output_uri='output/uri',
        stateful_working_dir='workding/dir',
        pipeline_info=info,
        pipeline_node=node)

    rehydrated = python_execution_binary_utils.deserialize_execution_info(
        python_execution_binary_utils.serialize_execution_info(original))

    # Every field must survive the round trip unchanged.
    self.CheckArtifactDict(rehydrated.input_dict, {'input': [artifact]})
    self.CheckArtifactDict(rehydrated.output_dict, {'output': [artifact]})
    self.assertEqual(rehydrated.exec_properties, properties)
    self.assertEqual(rehydrated.execution_output_uri, 'output/uri')
    self.assertEqual(rehydrated.stateful_working_dir, 'workding/dir')
    self.assertProtoEquals(rehydrated.pipeline_info, original.pipeline_info)
    self.assertProtoEquals(rehydrated.pipeline_node, original.pipeline_node)
예제 #8
0
 def testRunExecutorWithBeamPipelineArgs(self):
     """Runs a Beam executor through BeamExecutorOperator and checks output.

     The executor spec carries beam_pipeline_args ("--runner=DirectRunner");
     the executor class named in the spec validates that those args arrive.
     The returned ExecutorOutput is expected to hold one output artifact
     whose "name" custom property is "<pipeline id>.<node id>.my_model".
     """
     executor_spec = text_format.Parse(
         """
   python_executor_spec: {
       class_path: "tfx.orchestration.portable.beam_executor_operator_test.ValidateBeamPipelineArgsExecutor"
   }
   beam_pipeline_args: "--runner=DirectRunner"
 """, executable_spec_pb2.BeamExecutableSpec())
     operator = beam_executor_operator.BeamExecutorOperator(executor_spec)
     pipeline_node = pipeline_pb2.PipelineNode(
         node_info={'id': 'MyBeamNode'})
     pipeline_info = pipeline_pb2.PipelineInfo(id='MyPipeline')
     executor_output_uri = os.path.join(self.tmp_dir, 'executor_output')
     executor_output = operator.run_executor(
         data_types.ExecutionInfo(
             execution_id=1,
             input_dict={'input_key': [standard_artifacts.Examples()]},
             output_dict={'output_key': [standard_artifacts.Model()]},
             exec_properties={},
             execution_output_uri=executor_output_uri,
             pipeline_node=pipeline_node,
             pipeline_info=pipeline_info,
             pipeline_run_id=99))
     # Partial comparison: only the fields listed below are checked.
     self.assertProtoPartiallyEquals(
         """
        output_artifacts {
          key: "output_key"
          value {
            artifacts {
              custom_properties {
                key: "name"
                value {
                  string_value: "MyPipeline.MyBeamNode.my_model"
                }
              }
            }
          }
        }""", executor_output)
예제 #9
0
File: driver_test.py  Project: yifanmai/tfx
  def testRun(self):
    """End-to-end check of Driver.run over span/version-patterned inputs.

    Builds a fake input base containing span01/version01 data for two
    splits, runs the driver, and verifies the resolved exec properties
    (span, version, rewritten input config, fingerprint) plus the output
    Examples artifact's custom properties. The artifact's preset uri must
    be left untouched by the driver.
    """
    # Create input dir.
    self._input_base_path = os.path.join(self._test_dir, 'input_base')
    tf.io.gfile.makedirs(self._input_base_path)

    # Create PipelineInfo and PipelineNode.
    pipeline_info = pipeline_pb2.PipelineInfo()
    pipeline_node = pipeline_pb2.PipelineNode()

    # Fake previous outputs: one data file per split under span01/version01.
    span1_v1_split1 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split1', 'data')
    io_utils.write_string_file(span1_v1_split1, 'testing11')
    span1_v1_split2 = os.path.join(self._input_base_path, 'span01', 'version01',
                                   'split2', 'data')
    io_utils.write_string_file(span1_v1_split2, 'testing12')

    ir_driver = driver.Driver(self._mock_metadata, pipeline_info, pipeline_node)
    example = standard_artifacts.Examples()

    # Prepare the output dict.
    example.uri = 'my_uri'  # Will verify that this uri is not changed.
    output_dic = {utils.EXAMPLES_KEY: [example]}

    # Prepare exec_properties with {SPAN}/{VERSION} patterns to be resolved.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_base_path,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN}/version{VERSION}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN}/version{VERSION}/split2/*')
                ]),
                preserving_proto_field_name=True),
    }
    result = ir_driver.run(None, output_dic, exec_properties)
    # Assert exec_properties' values.
    exec_properties = result.exec_properties
    self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
    self.assertEqual(exec_properties[utils.VERSION_PROPERTY_NAME].int_value, 1)
    updated_input_config = example_gen_pb2.Input()
    json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                      updated_input_config)
    # The {SPAN}/{VERSION} placeholders must be substituted with 01/01.
    self.assertProtoEquals(
        """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)
    self.assertRegex(
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
    # Assert output_artifacts' values.
    self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
    output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
    self.assertEqual(output_example.uri, example.uri)
    self.assertEqual(
        output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
        '1')
    self.assertEqual(
        output_example.custom_properties[
            utils.VERSION_PROPERTY_NAME].string_value, '1')
    self.assertRegex(
        output_example.custom_properties[
            utils.FINGERPRINT_PROPERTY_NAME].string_value,
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
예제 #10
0
from absl.testing import parameterized
import tensorflow as tf
from tfx.dsl.io import fileio
from tfx.orchestration.portable import outputs_utils
from tfx.proto.orchestration import execution_result_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import standard_artifacts
from tfx.types.value_artifact import ValueArtifact
from tfx.utils import test_case_utils

from google.protobuf import text_format

# Shared PipelineInfo fixture; only the pipeline id is populated.
_PIPELINE_INFO = text_format.Parse("""
  id: "test_pipeline"
""", pipeline_pb2.PipelineInfo())

_PIPELINE_NODE = text_format.Parse(
    """
  node_info {
    id: "test_node"
  }
  outputs {
    outputs {
      key: "output_1"
      value {
        artifact_spec {
          type {
            id: 1
            name: "test_type_1"
            properties {