Ejemplo n.º 1
0
 def testEnableCache(self):
     """Checks that `enable_cache` is None by default and True when requested.

     Two FileBasedExampleGen components are built from the same input
     artifact, custom config and custom executor. Only the second one passes
     `enable_cache=True`.
     """
     input_base = standard_artifacts.ExternalArtifact()
     custom_config = example_gen_pb2.CustomConfig(
         custom_config=any_pb2.Any())
     example_gen_1 = component.FileBasedExampleGen(
         input=channel_utils.as_channel([input_base]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     # Without an explicit argument, the component leaves caching unset.
     self.assertIsNone(example_gen_1.enable_cache)
     example_gen_2 = component.FileBasedExampleGen(
         input=channel_utils.as_channel([input_base]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor),
         enable_cache=True)
     self.assertTrue(example_gen_2.enable_cache)
Ejemplo n.º 2
0
 def testConstructCustomExecutor(self):
     """Builds a FileBasedExampleGen with a custom executor from a path string.

     Verifies that the component reports `driver.Driver` as its driver class.
     Also verifies that its 'examples' output channel carries the standard
     Examples type name.
     """
     class_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
     gen = component.FileBasedExampleGen(
         input_base='path', custom_executor_spec=class_spec)
     self.assertEqual(driver.Driver, gen.driver_class)
     examples_channel = gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)
Ejemplo n.º 3
0
 def test_construct_custom_executor(self):
   """Builds a FileBasedExampleGen over an 'ExternalPath' artifact channel.

   Checks three things: the driver class, the 'ExamplesPath' output type,
   and that the first two output artifacts carry the 'train' and 'eval'
   splits, in that order.
   """
   source_artifact = types.TfxArtifact(type_name='ExternalPath')
   source_channel = channel.as_channel([source_artifact])
   gen = component.FileBasedExampleGen(
       input_base=source_channel, executor_class=TestExampleGenExecutor)
   self.assertEqual(driver.Driver, gen.driver_class)
   examples = gen.outputs.examples
   self.assertEqual('ExamplesPath', examples.type_name)
   artifacts = examples.get()
   for index, expected_split in enumerate(('train', 'eval')):
     self.assertEqual(expected_split, artifacts[index].split)
Ejemplo n.º 4
0
 def testConstructCustomExecutor(self):
   """Builds a FileBasedExampleGen from an ExternalArtifact channel.

   Checks three things: the driver class, the 'ExamplesPath' type of the
   'examples' output, and that its first two artifacts are the 'train' and
   'eval' splits, in that order.
   """
   external = standard_artifacts.ExternalArtifact()
   input_channel = channel_utils.as_channel([external])
   class_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
   gen = component.FileBasedExampleGen(
       input_base=input_channel, custom_executor_spec=class_spec)
   self.assertEqual(driver.Driver, gen.driver_class)
   examples_channel = gen.outputs['examples']
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   artifacts = examples_channel.get()
   for index, expected_split in enumerate(('train', 'eval')):
     self.assertEqual(expected_split, artifacts[index].split)
Ejemplo n.º 5
0
  def testConstructWithCustomConfig(self):
    """Round-trips a CustomConfig through the component's exec properties.

    The config is read back from exec_properties[CUSTOM_CONFIG_KEY] via
    proto_utils.json_to_proto. The parsed proto must equal the one supplied.
    """
    expected = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
    beam_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
    gen = component.FileBasedExampleGen(
        input_base='path',
        custom_config=expected,
        custom_executor_spec=beam_spec)

    config_key = standard_component_specs.CUSTOM_CONFIG_KEY
    parsed = example_gen_pb2.CustomConfig()
    proto_utils.json_to_proto(gen.exec_properties[config_key], parsed)
    self.assertEqual(expected, parsed)
Ejemplo n.º 6
0
    def testConstructWithCustomConfig(self):
        """Checks that a CustomConfig survives the trip through exec_properties.

        The value under exec_properties['custom_config'] is parsed with
        json_format.Parse. It must reproduce the original proto.
        """
        original_config = example_gen_pb2.CustomConfig(
            custom_config=any_pb2.Any())
        class_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
        gen = component.FileBasedExampleGen(
            input_base='path',
            custom_config=original_config,
            custom_executor_spec=class_spec)

        round_tripped = example_gen_pb2.CustomConfig()
        json_text = gen.exec_properties['custom_config']
        json_format.Parse(json_text, round_tripped)
        self.assertEqual(original_config, round_tripped)
Ejemplo n.º 7
0
  def testConstructWithCustomConfig(self):
    """Round-trips a CustomConfig when the input is an artifact channel.

    The component is fed a channel holding one ExternalArtifact. The JSON
    stored under exec_properties['custom_config'] must parse back into a
    proto equal to the supplied config.
    """
    external = standard_artifacts.ExternalArtifact()
    original_config = example_gen_pb2.CustomConfig(
        custom_config=any_pb2.Any())
    gen = component.FileBasedExampleGen(
        input_base=channel_utils.as_channel([external]),
        custom_config=original_config,
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            TestExampleGenExecutor))

    round_tripped = example_gen_pb2.CustomConfig()
    json_text = gen.exec_properties['custom_config']
    json_format.Parse(json_text, round_tripped)
    self.assertEqual(original_config, round_tripped)
Ejemplo n.º 8
0
 def testConstructCustomExecutor(self):
     """Builds a FileBasedExampleGen with a custom executor from a path string.

     Checks three things: the driver class, the standard Examples type of
     the 'examples' output, and that the output holds exactly one artifact
     whose decoded split names are ['train', 'eval'].
     """
     class_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
     gen = component.FileBasedExampleGen(
         input_base='path', custom_executor_spec=class_spec)
     self.assertEqual(driver.Driver, gen.driver_class)

     examples_channel = gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     artifacts = examples_channel.get()
     self.assertEqual(1, len(artifacts))
     decoded_splits = artifact_utils.decode_split_names(
         artifacts[0].split_names)
     self.assertEqual(['train', 'eval'], decoded_splits)
Ejemplo n.º 9
0
 def testConstructWithStaticRangeConfig(self):
     """Round-trips a static RangeConfig through exec_properties.

     The range covers the single span 1..1. The JSON stored under
     'range_config' must parse back into an equal RangeConfig proto.
     """
     single_span = range_config_pb2.StaticRange(
         start_span_number=1, end_span_number=1)
     expected_range = range_config_pb2.RangeConfig(static_range=single_span)
     gen = component.FileBasedExampleGen(
         input_base='path',
         range_config=expected_range,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     parsed_range = range_config_pb2.RangeConfig()
     json_format.Parse(gen.exec_properties['range_config'], parsed_range)
     self.assertEqual(expected_range, parsed_range)
Ejemplo n.º 10
0
    def test_construct_with_custom_config(self):
        """Round-trips a CustomConfig using the legacy executor_class API.

        The input is a channel holding one 'ExternalPath' artifact. The JSON
        stored under exec_properties['custom_config'] must parse back into a
        proto equal to the supplied config.
        """
        source_artifact = types.TfxArtifact(type_name='ExternalPath')
        original_config = example_gen_pb2.CustomConfig(
            custom_config=any_pb2.Any())
        gen = component.FileBasedExampleGen(
            input_base=channel.as_channel([source_artifact]),
            custom_config=original_config,
            executor_class=TestExampleGenExecutor)

        round_tripped = example_gen_pb2.CustomConfig()
        json_text = gen.exec_properties['custom_config']
        json_format.Parse(json_text, round_tripped)
        self.assertEqual(original_config, round_tripped)
Ejemplo n.º 11
0
 def testConstructWithStaticRangeConfig(self):
   """Round-trips a static RangeConfig via proto_utils.json_to_proto.

   The range covers the single span 1..1. It is read back from
   exec_properties[RANGE_CONFIG_KEY] and must equal the supplied config.
   """
   single_span = range_config_pb2.StaticRange(
       start_span_number=1, end_span_number=1)
   expected_range = range_config_pb2.RangeConfig(static_range=single_span)
   beam_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
   gen = component.FileBasedExampleGen(
       input_base='path',
       range_config=expected_range,
       custom_executor_spec=beam_spec)
   range_key = standard_component_specs.RANGE_CONFIG_KEY
   parsed_range = range_config_pb2.RangeConfig()
   proto_utils.json_to_proto(gen.exec_properties[range_key], parsed_range)
   self.assertEqual(expected_range, parsed_range)