def test_construct(self):
    """An ExampleGen built from an input_base channel uses `driver.Driver`.

    Checks that 'input-base' is registered in the component's input_dict,
    that `driver.Driver` is the selected driver, and that the examples
    output channel is typed 'ExamplesPath' with 'train' and 'eval' as its
    first two artifact splits.
    """
    base_artifact = types.TfxArtifact(type_name='ExternalPath')
    base_channel = channel.as_channel([base_artifact])
    gen = component.ExampleGen(
        executor=TestExampleGenExecutor, input_base=base_channel)

    self.assertIn('input-base', gen.input_dict)
    self.assertEqual(driver.Driver, gen.driver)

    examples_channel = gen.outputs.examples
    self.assertEqual('ExamplesPath', examples_channel.type_name)
    produced = examples_channel.get()
    # Index explicitly so a too-short collection still fails with IndexError.
    for position, expected_split in enumerate(('train', 'eval')):
        self.assertEqual(expected_split, produced[position].split)
def test_construct_without_input_base(self):
    """An ExampleGen given only an input_config falls back to BaseDriver.

    No input_base is passed, so input_dict must be empty and the driver must
    be `base_driver.BaseDriver`. The examples output is still expected to be
    typed 'ExamplesPath' with 'train' and 'eval' as its first two splits,
    even though the input_config declares a single split named 'single'.
    """
    # NOTE(review): the train/eval outputs presumably come from the
    # component's default output split config — confirm in component.py.
    query_split = example_gen_pb2.Input.Split(name='single', pattern='query')
    query_config = example_gen_pb2.Input(splits=[query_split])
    gen = component.ExampleGen(
        executor=TestExampleGenExecutor, input_config=query_config)

    self.assertEqual({}, gen.input_dict)
    self.assertEqual(base_driver.BaseDriver, gen.driver)

    examples_channel = gen.outputs.examples
    self.assertEqual('ExamplesPath', examples_channel.type_name)
    produced = examples_channel.get()
    for position, expected_split in enumerate(('train', 'eval')):
        self.assertEqual(expected_split, produced[position].split)
def test_construct_with_input_config(self):
    """A three-split input_config yields train/eval/test example artifacts.

    The input_config maps each split name to a glob pattern. The examples
    output channel must be typed 'ExamplesPath' and its first three artifacts
    must carry the 'train', 'eval' and 'test' splits, in that order.
    """
    split_patterns = (
        ('train', 'train/*'),
        ('eval', 'eval/*'),
        ('test', 'test/*'),
    )
    input_splits = [
        example_gen_pb2.Input.Split(name=split_name, pattern=glob)
        for split_name, glob in split_patterns
    ]
    base_artifact = types.TfxArtifact(type_name='ExternalPath')
    gen = component.ExampleGen(
        executor=TestExampleGenExecutor,
        input_base=channel.as_channel([base_artifact]),
        input_config=example_gen_pb2.Input(splits=input_splits))

    examples_channel = gen.outputs.examples
    self.assertEqual('ExamplesPath', examples_channel.type_name)
    produced = examples_channel.get()
    for position, expected_split in enumerate(('train', 'eval', 'test')):
        self.assertEqual(expected_split, produced[position].split)
def test_construct_with_output_config(self):
    """A custom hash-bucket output_config yields train/eval/test artifacts.

    The output_config requests three splits ('train' with 2 hash buckets,
    'eval' and 'test' with 1 each). The examples output channel must be
    typed 'ExamplesPath' and its first three artifacts must carry the
    'train', 'eval' and 'test' splits, in that order.
    """
    # Build the input artifact with TfxArtifact, the same class every other
    # construction test in this file uses.
    input_base = types.TfxArtifact(type_name='ExternalPath')
    example_gen = component.ExampleGen(
        executor=TestExampleGenExecutor,
        input_base=channel.as_channel([input_base]),
        output_config=example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
                example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
                example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
            ])))
    self.assertEqual('ExamplesPath', example_gen.outputs.examples.type_name)
    artifact_collection = example_gen.outputs.examples.get()
    self.assertEqual('train', artifact_collection[0].split)
    self.assertEqual('eval', artifact_collection[1].split)
    self.assertEqual('test', artifact_collection[2].split)