Example #1
0
 def testMainEmptyInputs(self):
     """Runs run_executor.main with default-constructed artifacts.

     The values recorded by ArgsCapture are compared against what was
     serialized into the command line flags: input/output dict keys and the
     exec properties.
     """
     inputs = {
         'x': [standard_artifacts.ExternalArtifact() for _ in range(2)],
     }
     outputs = {'y': [standard_artifacts.Examples()]}
     exec_properties = {'a': 'b'}

     # Build each flag separately, then hand them over as one argv list.
     executor_flag = '--executor_class_path={}.{}'.format(
         FakeExecutor.__module__, FakeExecutor.__name__)
     inputs_flag = '--inputs={}'.format(
         artifact_utils.jsonify_artifact_dict(inputs))
     outputs_flag = '--outputs={}'.format(
         artifact_utils.jsonify_artifact_dict(outputs))
     exec_properties_flag = '--exec-properties={}'.format(
         json.dumps(exec_properties))
     args = [executor_flag, inputs_flag, outputs_flag, exec_properties_flag]

     with ArgsCapture() as args_capture:
         run_executor.main(args)
         # TODO(b/131417512): Add equal comparison to types.Artifact class so we
         # can use asserters.
         captured_input_keys = set(args_capture.input_dict.keys())
         self.assertSetEqual(captured_input_keys, set(inputs.keys()))
         captured_output_keys = set(args_capture.output_dict.keys())
         self.assertSetEqual(captured_output_keys, set(outputs.keys()))
         self.assertDictEqual(args_capture.exec_properties, exec_properties)
Example #2
0
  def setUp(self):
    """Wraps a CsvExampleGen and a StatisticsGen as KFP base components.

    The example gen's hash bucket count is a RuntimeParameter and the
    pipeline root handed to the wrappers is a dsl.PipelineParam.
    """
    super(BaseComponentWithPipelineParamTest, self).setUp()

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
    example_gen_buckets = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    # TFX side: CsvExampleGen (single parameterized split) -> StatisticsGen.
    examples = standard_artifacts.ExternalArtifact()
    split = {'name': 'examples', 'hash_buckets': example_gen_buckets}
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]),
        output_config={'split_config': {'splits': [split]}})
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    metadata_config.mysql_db_service_host.environment_variable = (
        'MYSQL_SERVICE_HOST')
    self._metadata_config = metadata_config
    self._tfx_ir = pipeline_pb2.Pipeline()

    # Keyword arguments that are the same for both wrappers. `depends_on` is
    # left out on purpose so that each wrapper receives its own fresh set.
    shared_kwargs = dict(
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=example_gen, depends_on=set(), **shared_kwargs)
      self.statistics_gen = base_component.BaseComponent(
          component=statistics_gen, depends_on=set(), **shared_kwargs)

    self.tfx_example_gen = example_gen
    self.tfx_statistics_gen = statistics_gen
Example #3
0
  def setUp(self):
    """Creates the input directory, a driver on mocked metadata and inputs."""
    super(DriverTest, self).setUp()

    # Create the input base directory for the running test method.
    outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
    self._input_base_path = os.path.join(
        outputs_root, self._testMethodName, 'input_base')
    tf.io.gfile.makedirs(self._input_base_path)

    # Mock metadata.
    self._mock_metadata = tf.test.mock.Mock()
    self._example_gen_driver = driver.Driver(self._mock_metadata)

    # Create input channels: one external artifact pointing at the directory.
    input_base = standard_artifacts.ExternalArtifact()
    input_base.uri = self._input_base_path
    self._input_channels = {
        'input_base': channel_utils.as_channel([input_base])
    }

    # Create exec properties: two splits whose patterns contain {SPAN}.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='span{SPAN}/split2/*'),
    ])
    self._exec_properties = {
        'input_config': json_format.MessageToJson(input_config),
    }
Example #4
0
  def setUp(self):
    """Wraps the StatisticsGen of a two-component pipeline for KFP."""
    super(BaseComponentTest, self).setUp()

    # TFX side: CsvExampleGen feeding StatisticsGen.
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')
    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    metadata_config.mysql_db_service_host.environment_variable = (
        'MYSQL_SERVICE_HOST')
    self._metadata_config = metadata_config

    # KFP side: only the StatisticsGen is wrapped.
    launcher_class = in_process_component_launcher.InProcessComponentLauncher
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=launcher_class,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
      )
    self.tfx_component = statistics_gen
Example #5
0
  def setUp(self):
    """Wraps the StatisticsGen of a two-component pipeline for KFP."""
    super(BaseComponentTest, self).setUp()

    # TFX side: CsvExampleGen feeding StatisticsGen.
    external_artifact = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input_base=channel_utils.as_channel([external_artifact]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples, instance_name='foo')
    pipeline = tfx_pipeline.Pipeline(
        pipeline_name='test_pipeline',
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    metadata_config.mysql_db_service_host.environment_variable = (
        'MYSQL_SERVICE_HOST')
    self._metadata_config = metadata_config

    # KFP side: only the StatisticsGen is wrapped.
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          depends_on=set(),
          pipeline=pipeline,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
      )
    self.tfx_component = statistics_gen
  def testCsvExampleGenWrapper(self):
    """Runs CsvExampleGenWrapper and checks the examples metadata it wrote."""
    input_base = standard_artifacts.ExternalArtifact(split='')
    input_base.uri = '/path/to/dataset'

    # The executor class is patched for the whole run-and-verify sequence.
    with patch.object(executor, 'Executor', autospec=True):
      wrapper_args = argparse.Namespace(
          exec_properties=json.dumps(self.exec_properties),
          outputs=artifact_utils.jsonify_artifact_dict(
              {'examples': self.examples}),
          executor_class_path=(
              'tfx.components.example_gen.csv_example_gen.executor.Executor'
          ),
          input_base=json.dumps([input_base.json_dict()]))
      wrapper = executor_wrappers.CsvExampleGenWrapper(wrapper_args)
      wrapper.run(output_basedir=self.output_basedir)

      # TODO(b/133011207): Validate arguments for executor and Do() method.

      metadata_file = os.path.join(
          self.output_basedir, 'output/ml_metadata/examples')

      # Expect that span and path are resolved.
      expected_output_examples = standard_artifacts.Examples(split='dummy')
      expected_output_examples.span = 1
      expected_output_examples.uri = (
          '/path/to/output/csv_example_gen/examples/mock_workflow_id/dummy/')
      expected_metadata = [expected_output_examples.json_dict()]

      with tf.gfile.GFile(metadata_file) as f:
        written_metadata = json.loads(f.read())
      self.assertEqual(expected_metadata, written_metadata)
Example #7
0
  def setUp(self):
    """Points the input dict at the 'external' folder of the test data."""
    # 'testdata' sits next to the directory two levels above this file's
    # own directory.
    this_dir = os.path.dirname(__file__)
    testdata_parent = os.path.dirname(os.path.dirname(this_dir))
    input_data_dir = os.path.join(testdata_parent, 'testdata')

    # Create input dict.
    external_artifact = standard_artifacts.ExternalArtifact()
    external_artifact.uri = os.path.join(input_data_dir, 'external')
    self._input_dict = {'input_base': [external_artifact]}
Example #8
0
 def testEnableCache(self):
     """Checks enable_cache is None unless it is passed to the constructor."""
     external_artifact = standard_artifacts.ExternalArtifact()

     # Without the argument the attribute is expected to be None.
     default_gen = component.ImportExampleGen(
         input=channel_utils.as_channel([external_artifact]))
     self.assertEqual(None, default_gen.enable_cache)

     # With enable_cache=True the attribute is expected to be True.
     cached_gen = component.ImportExampleGen(
         input=channel_utils.as_channel([external_artifact]),
         enable_cache=True)
     self.assertEqual(True, cached_gen.enable_cache)
Example #9
0
 def testConstruct(self):
     """Checks the output type name and the train/eval artifact splits."""
     import_example_gen = component.ImportExampleGen(
         input=channel_utils.as_channel(
             [standard_artifacts.ExternalArtifact()]))

     examples_channel = import_example_gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     # One output artifact per expected split, in this order.
     artifact_collection = import_example_gen.outputs['examples'].get()
     for position, split_name in enumerate(('train', 'eval')):
         self.assertEqual(split_name, artifact_collection[position].split)
Example #10
0
  def setUp(self):
    """Points the input dict at the 'external' folder of the test data."""
    super(ExecutorTest, self).setUp()
    # 'testdata' sits next to the directory two levels above this file's
    # own directory.
    this_dir = os.path.dirname(__file__)
    testdata_parent = os.path.dirname(os.path.dirname(this_dir))
    input_data_dir = os.path.join(testdata_parent, 'testdata')

    # Create input dict.
    external_artifact = standard_artifacts.ExternalArtifact()
    external_artifact.uri = os.path.join(input_data_dir, 'external')
    self._input_dict = {INPUT_KEY: [external_artifact]}
Example #11
0
 def testConstruct(self):
     """Checks the output type name and the train/eval artifact splits."""
     csv_example_gen = component.CsvExampleGen(
         input_base=channel_utils.as_channel(
             [standard_artifacts.ExternalArtifact()]))

     examples_channel = csv_example_gen.outputs['examples']
     self.assertEqual('ExamplesPath', examples_channel.type_name)

     # One output artifact per expected split, in this order.
     artifact_collection = csv_example_gen.outputs['examples'].get()
     for position, split_name in enumerate(('train', 'eval')):
         self.assertEqual(split_name, artifact_collection[position].split)
Example #12
0
 def testConstruct(self):
   """Checks the output type name and the encoded split names."""
   csv_example_gen = component.CsvExampleGen(
       input=channel_utils.as_channel(
           [standard_artifacts.ExternalArtifact()]))

   examples_channel = csv_example_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)

   # A single artifact is expected; its split_names decode to both splits.
   artifact_collection = csv_example_gen.outputs['examples'].get()
   self.assertEqual(1, len(artifact_collection))
   decoded_splits = artifact_utils.decode_split_names(
       artifact_collection[0].split_names)
   self.assertEqual(['train', 'eval'], decoded_splits)
Example #13
0
 def testConstructCustomExecutor(self):
   """Builds FileBasedExampleGen with a custom executor spec and checks it."""
   external_artifact = standard_artifacts.ExternalArtifact()
   custom_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
   example_gen = component.FileBasedExampleGen(
       input_base=channel_utils.as_channel([external_artifact]),
       custom_executor_spec=custom_spec)

   self.assertEqual(driver.Driver, example_gen.driver_class)
   self.assertEqual('ExamplesPath', example_gen.outputs['examples'].type_name)

   # One output artifact per expected split, in this order.
   artifact_collection = example_gen.outputs['examples'].get()
   for position, split_name in enumerate(('train', 'eval')):
     self.assertEqual(split_name, artifact_collection[position].split)
Example #14
0
 def testConstructSubclassFileBased(self):
   """Checks inputs, driver, outputs and config of the test subclass."""
   example_gen = TestFileBasedExampleGenComponent(
       input_base=channel_utils.as_channel(
           [standard_artifacts.ExternalArtifact()]))

   self.assertIn('input_base', example_gen.inputs.get_all())
   self.assertEqual(driver.Driver, example_gen.driver_class)
   self.assertEqual('ExamplesPath', example_gen.outputs['examples'].type_name)
   # No custom_config was passed, so the lookup is expected to yield None.
   self.assertIsNone(example_gen.exec_properties.get('custom_config'))

   # One output artifact per expected split, in this order.
   artifact_collection = example_gen.outputs['examples'].get()
   for position, split_name in enumerate(('train', 'eval')):
     self.assertEqual(split_name, artifact_collection[position].split)
Example #15
0
    def setUp(self):
        """Wraps a CsvExampleGen and a StatisticsGen as KFP base components.

        The example gen's split name is a RuntimeStringParameter and the
        pipeline root handed to the wrappers is a dsl.PipelineParam.
        """
        super(BaseComponentWithPipelineParamTest, self).setUp()

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
        example_gen_output_name = (
            runtime_string_parameter.RuntimeStringParameter(
                name='example-gen-output-name',
                default='default-to-be-discarded'))

        # TFX side: CsvExampleGen (one parameterized split) -> StatisticsGen.
        examples = standard_artifacts.ExternalArtifact()
        split = example_gen_pb2.SplitConfig.Split(
            name=example_gen_output_name, hash_buckets=10)
        output_config = example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[split]))
        example_gen = csv_example_gen_component.CsvExampleGen(
            input=channel_utils.as_channel([examples]),
            output_config=output_config)
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples'], instance_name='foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        metadata_config.mysql_db_service_host.environment_variable = (
            'MYSQL_SERVICE_HOST')
        self._metadata_config = metadata_config

        # Keyword arguments that are the same for both wrappers. `depends_on`
        # is left out on purpose so each wrapper receives its own fresh set.
        shared_kwargs = dict(
            component_launcher_class=(
                in_process_component_launcher.InProcessComponentLauncher),
            pipeline=pipeline,
            pipeline_name=self._test_pipeline_name,
            pipeline_root=test_pipeline_root,
            tfx_image='container_image',
            kubeflow_metadata_config=self._metadata_config,
            component_config=None)
        with dsl.Pipeline('test_pipeline'):
            self.example_gen = base_component.BaseComponent(
                component=example_gen, depends_on=set(), **shared_kwargs)
            self.statistics_gen = base_component.BaseComponent(
                component=statistics_gen, depends_on=set(), **shared_kwargs)

        self.tfx_example_gen = example_gen
        self.tfx_statistics_gen = statistics_gen
Example #16
0
def external_input(uri: Text) -> types.Channel:
    """Wraps an external location in a single-artifact input channel.

    Args:
      uri: external path.

    Returns:
      A channel built from one ExternalArtifact whose uri is set to `uri`.
    """
    external_artifact = standard_artifacts.ExternalArtifact()
    external_artifact.uri = uri
    artifacts = [external_artifact]
    return channel_utils.as_channel(artifacts)
Example #17
0
def external_input(uri: Any) -> types.Channel:
    """Wraps an external location in a single-artifact input channel.

    Args:
      uri: external path, can be RuntimeParameter. It is converted with
        str() before being stored on the artifact.

    Returns:
      A channel built from one ExternalArtifact whose uri is `str(uri)`.
    """
    external_artifact = standard_artifacts.ExternalArtifact()
    external_artifact.uri = str(uri)
    artifacts = [external_artifact]
    return channel_utils.as_channel(artifacts)
Example #18
0
  def testConstructWithCustomConfig(self):
    """Checks custom_config can be parsed back from exec_properties."""
    external_artifact = standard_artifacts.ExternalArtifact()
    custom_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
    custom_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
    example_gen = component.FileBasedExampleGen(
        input_base=channel_utils.as_channel([external_artifact]),
        custom_config=custom_config,
        custom_executor_spec=custom_spec)

    # Parse the stored exec property into a fresh proto and compare it with
    # the proto that was passed to the constructor.
    serialized_config = example_gen.exec_properties['custom_config']
    stored_custom_config = example_gen_pb2.CustomConfig()
    json_format.Parse(serialized_config, stored_custom_config)
    self.assertEqual(custom_config, stored_custom_config)
Example #19
0
 def testConstructWithInputConfig(self):
   """Checks that three input splits yield three output artifacts."""
   external_artifact = standard_artifacts.ExternalArtifact()
   input_config = example_gen_pb2.Input(splits=[
       example_gen_pb2.Input.Split(name='train', pattern='train/*'),
       example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
       example_gen_pb2.Input.Split(name='test', pattern='test/*'),
   ])
   example_gen = TestFileBasedExampleGenComponent(
       input_base=channel_utils.as_channel([external_artifact]),
       input_config=input_config)

   self.assertEqual('ExamplesPath', example_gen.outputs['examples'].type_name)

   # One output artifact per configured split, in configuration order.
   artifact_collection = example_gen.outputs['examples'].get()
   for position, split_name in enumerate(('train', 'eval', 'test')):
     self.assertEqual(split_name, artifact_collection[position].split)
Example #20
0
 def testConstructWithInputConfig(self):
     """Checks that three input splits end up in the artifact split names."""
     external_artifact = standard_artifacts.ExternalArtifact()
     input_config = example_gen_pb2.Input(splits=[
         example_gen_pb2.Input.Split(name='train', pattern='train/*'),
         example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
         example_gen_pb2.Input.Split(name='test', pattern='test/*'),
     ])
     example_gen = TestFileBasedExampleGenComponent(
         input=channel_utils.as_channel([external_artifact]),
         input_config=input_config)

     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)

     # A single artifact is expected; its split_names decode to all splits.
     artifact_collection = example_gen.outputs['examples'].get()
     self.assertEqual(1, len(artifact_collection))
     decoded_splits = artifact_utils.decode_split_names(
         artifact_collection[0].split_names)
     self.assertEqual(['train', 'eval', 'test'], decoded_splits)
Example #21
0
 def testConstructWithOutputConfig(self):
     """Checks that three output splits yield three output artifacts."""
     external_artifact = standard_artifacts.ExternalArtifact()
     split_config = example_gen_pb2.SplitConfig(splits=[
         example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
         example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
         example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
     ])
     example_gen = TestFileBasedExampleGenComponent(
         input_base=channel_utils.as_channel([external_artifact]),
         output_config=example_gen_pb2.Output(split_config=split_config))

     self.assertEqual('ExamplesPath',
                      example_gen.outputs.examples.type_name)

     # One output artifact per configured split, in configuration order.
     artifact_collection = example_gen.outputs.examples.get()
     for position, split_name in enumerate(('train', 'eval', 'test')):
         self.assertEqual(split_name, artifact_collection[position].split)
Example #22
0
 def testEnableCache(self):
     """Checks enable_cache is None unless it is passed to the constructor."""
     external_artifact = standard_artifacts.ExternalArtifact()
     custom_config = example_gen_pb2.CustomConfig(
         custom_config=any_pb2.Any())

     # Without the argument the attribute is expected to be None.
     default_gen = component.FileBasedExampleGen(
         input=channel_utils.as_channel([external_artifact]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     self.assertEqual(None, default_gen.enable_cache)

     # With enable_cache=True the attribute is expected to be True.
     cached_gen = component.FileBasedExampleGen(
         input=channel_utils.as_channel([external_artifact]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor),
         enable_cache=True)
     self.assertEqual(True, cached_gen.enable_cache)
Example #23
0
    def __init__(self,
                 input_data: types.Channel = None,
                 output_data: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Creates a HelloComponent.

        Args:
          input_data: A Channel of type `standard_artifacts.InferenceResult`.
          output_data: A Channel of type `standard_artifacts.ExternalArtifact`.
            When it is omitted (or otherwise falsy), a channel holding one
            newly created ExternalArtifact is used instead.
          instance_name: Optional unique name. Necessary if multiple Hello
            components are declared in the same pipeline.
        """
        # Fall back to a fresh single-artifact channel when none was supplied.
        output_data = output_data or channel_utils.as_channel(
            [standard_artifacts.ExternalArtifact()])

        component_spec = HelloComponentSpec(
            input_data=input_data, output_data=output_data)
        super(HelloComponent, self).__init__(
            spec=component_spec, instance_name=instance_name)
Example #24
0
def CsvExampleGen_GCS(
    # Inputs
    input_base_path: 'ExternalPath',
    # Outputs
    example_artifacts_path: 'ExamplesPath',
    # Execution properties
    input_config: 'ExampleGen.Input' = None,
    output_config: 'ExampleGen.Output' = None,
) -> NamedTuple('Outputs', [
    ('example_artifacts', 'ExamplesPath'),
]):
    """Executes the CsvExampleGen component.

    Args:
      input_base_path: URI of an external directory with csv files inside
        (required). It becomes the uri of the single ExternalArtifact that is
        passed to CsvExampleGen as its `input` channel.
      example_artifacts_path: Base URI for the output examples. Each output
        artifact gets this uri, extended by the artifact's split name when
        the artifact has one.
      input_config: JSON-serialized example_gen_pb2.Input instance, providing
        input configuration. If unset, the files under input_base_path will
        be treated as a single split.
      output_config: JSON-serialized example_gen_pb2.Output instance,
        providing output configuration. If unset, default splits will be
        'train' and 'eval' with size 2:1.

    Returns:
      example_artifacts: Artifact of type 'ExamplesPath' for output train and
        eval examples. The returned value is `example_artifacts_path`.
    """

    # NOTE(review): the imports are function-local, presumably so that the
    # function stays self-contained when it is packaged as a pipeline
    # component -- confirm before hoisting them to module level.
    import os
    from google.protobuf import json_format
    from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
    from tfx.proto import example_gen_pb2
    from tfx.types import standard_artifacts
    from tfx.types import channel_utils

    # Create the input channel from the external path.
    input_base = standard_artifacts.ExternalArtifact()
    input_base.uri = input_base_path
    input_base_channel = channel_utils.as_channel([input_base])

    # The configs arrive as JSON strings; parse them into protos when given.
    input_config_obj = None
    if input_config:
        input_config_obj = example_gen_pb2.Input()
        json_format.Parse(input_config, input_config_obj)

    output_config_obj = None
    if output_config:
        output_config_obj = example_gen_pb2.Output()
        json_format.Parse(output_config, output_config_obj)

    component_class_instance = CsvExampleGen(
        input=input_base_channel,
        input_config=input_config_obj,
        output_config=output_config_obj,
    )

    # component_class_instance.inputs/outputs are wrappers that do not behave
    # like real dictionaries. The underlying dict can be accessed using
    # .get_all(). Channel artifacts can be accessed by calling .get().
    input_dict = {
        name: channel.get()
        for name, channel in component_class_instance.inputs.get_all().items()
    }
    output_dict = {
        name: channel.get()
        for name, channel in
        component_class_instance.outputs.get_all().items()
    }
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for output_artifact in output_dict['examples']:
        output_artifact.uri = example_artifacts_path
        if output_artifact.split:
            output_artifact.uri = os.path.join(output_artifact.uri,
                                               output_artifact.split)

    print('component instance: ' + str(component_class_instance))

    # Run the component's executor directly on the prepared dictionaries.
    executor = CsvExampleGen.EXECUTOR_SPEC.executor_class()
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    return (example_artifacts_path, )
Example #25
0
    def test_prepare_input_for_processing(self):
        """Checks the artifact id the driver picks on cache miss and hit."""
        # Create input splits: one data file per split with a fixed mtime.
        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        input_base_path = os.path.join(test_dir, 'input_base')
        for split_dir, content, mtime in (('split1', 'testing', 1),
                                          ('split2', 'testing2', 3)):
            data_file = os.path.join(input_base_path, split_dir, 'data')
            io_utils.write_string_file(data_file, content)
            os.utime(data_file, (0, mtime))

        # Mock metadata.
        mock_metadata = tf.test.mock.Mock()
        example_gen_driver = driver.Driver(mock_metadata)

        # Mock artifacts with ids 4, 3, 2, 1 (in that order).
        matching_fingerprint = (
            'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
            'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')
        artifacts = []
        for artifact_id in (4, 3, 2, 1):
            artifact = metadata_store_pb2.Artifact()
            artifact.id = artifact_id
            artifact.uri = input_base_path
            # Only odd ids will be matched
            if artifact_id % 2 == 1:
                fingerprint = matching_fingerprint
            else:
                fingerprint = 'not_match'
            artifact.custom_properties[
                'input_fingerprint'].string_value = fingerprint
            artifacts.append(artifact)

        # Create input dict.
        input_base = standard_artifacts.ExternalArtifact()
        input_base.uri = input_base_path
        input_dict = {'input_base': [input_base]}
        # Create exec properties.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
            example_gen_pb2.Input.Split(name='s2', pattern='split2/*'),
        ])
        exec_properties = {
            'input_config': json_format.MessageToJson(input_config),
        }

        # Cache not hit: the lookup only returns the non-matching artifact 4
        # and publishing returns artifact 1.
        mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
        mock_metadata.publish_artifacts.return_value = [artifacts[3]]
        miss_result = example_gen_driver._prepare_input_for_processing(
            copy.deepcopy(input_dict), exec_properties)
        self.assertEqual(1, len(miss_result))
        self.assertEqual(1, len(miss_result['input_base']))
        miss_input_base = miss_result['input_base'][0]
        self.assertEqual(1, miss_input_base.id)
        self.assertEqual(input_base_path, miss_input_base.uri)

        # Cache hit: the lookup returns all artifacts and nothing is published.
        mock_metadata.get_artifacts_by_uri.return_value = artifacts
        mock_metadata.publish_artifacts.return_value = []
        hit_result = example_gen_driver._prepare_input_for_processing(
            copy.deepcopy(input_dict), exec_properties)
        self.assertEqual(1, len(hit_result))
        self.assertEqual(1, len(hit_result['input_base']))
        hit_input_base = hit_result['input_base'][0]
        self.assertEqual(3, hit_input_base.id)
        self.assertEqual(input_base_path, hit_input_base.uri)