def testConstructWithOutputConfig(self):
  """Checks that BigQueryExampleGen accepts an explicit output config."""
  split_config = example_gen_pb2.SplitConfig(splits=[
      example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
      example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
  ])
  big_query_example_gen = component.BigQueryExampleGen(
      query='query',
      output_config=example_gen_pb2.Output(split_config=split_config))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_example_gen.outputs['examples'].type_name)
def _testFeatureBasedPartition(self, partition_feature_name):
  """Stores a feature-partitioned output config in exec properties."""
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(
          splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          ],
          partition_feature_name=partition_feature_name))
  self._exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY] = (
      proto_utils.proto_to_json(output_config))
def testDo(self, mock_client):
  """Runs the BigQuery executor end to end and verifies split outputs."""
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = (
      self._schema)
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact the executor writes into.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  output_dict = {'examples': [examples]}

  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
  ])
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
  exec_properties = {
      'input_config': proto_utils.proto_to_json(input_config),
      'output_config': proto_utils.proto_to_json(output_config),
  }

  # Run executor with an explicit Beam project so the client mock is hit.
  big_query_example_gen = executor.Executor(
      base_executor.BaseExecutor.Context(
          beam_pipeline_args=['--project=test-project']))
  big_query_example_gen.Do({}, output_dict, exec_properties)

  mock_client.assert_called_with(project='test-project')
  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'Split-train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'Split-eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_file))
  self.assertTrue(fileio.exists(eval_file))
  self.assertGreater(
      fileio.open(train_file).size(), fileio.open(eval_file).size())
def _testFeatureBasedPartition(self, partition_feature_name):
  """Stores a feature-partitioned output config in exec properties."""
  split_config = example_gen_pb2.SplitConfig(
      splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ],
      partition_feature_name=partition_feature_name)
  self._exec_properties[utils.OUTPUT_CONFIG_KEY] = json_format.MessageToJson(
      example_gen_pb2.Output(split_config=split_config))
def testDo(self, mock_client):
  """Runs the BigQuery executor and checks train/eval outputs exist."""
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = (
      self._schema)
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact the executor writes into.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  output_dict = {'examples': [examples]}

  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
  ])
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              input_config, preserving_proto_field_name=True),
      'output_config':
          json_format.MessageToJson(
              output_config, preserving_proto_field_name=True),
  }

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_file))
  self.assertTrue(tf.io.gfile.exists(eval_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_file).size(),
      tf.io.gfile.GFile(eval_file).size())
def make_default_output_config(input_config: example_gen_pb2.Input
                              ) -> example_gen_pb2.Output:
  """Returns default output config based on input config."""
  if len(input_config.splits) > 1:
    # Output splits mirror the input splits, so no split config is needed.
    return example_gen_pb2.Output()
  # Single input split: fan out into 'train'/'eval' at a 2:1 ratio.
  return example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
def testDo(self):
  """Runs the Avro executor and verifies the generated splits."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact the executor writes into.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  output_dict = {utils.EXAMPLES_KEY: [examples]}

  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
  ])
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_data_dir,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              input_config, preserving_proto_field_name=True),
      utils.OUTPUT_CONFIG_KEY:
          json_format.MessageToJson(
              output_config, preserving_proto_field_name=True),
  }

  # Run executor.
  avro_example_gen = avro_executor.Executor()
  avro_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_file))
  self.assertTrue(fileio.exists(eval_file))
  self.assertGreater(
      fileio.open(train_file).size(), fileio.open(eval_file).size())
def testDo(self, mock_client):
  """Runs the BigQuery executor with per-split artifacts and checks outputs."""
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = (
      self._schema)
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # One artifact per output split.
  train_examples = types.TfxArtifact(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(out_dir, 'train')
  eval_examples = types.TfxArtifact(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(out_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  exec_properties = {
      'input':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'output':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ]))),
  }

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, exec_properties)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(train_examples.uri,
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(eval_examples.uri,
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_file))
  self.assertTrue(tf.gfile.Exists(eval_file))
  self.assertGreater(
      tf.gfile.GFile(train_file).size(), tf.gfile.GFile(eval_file).size())
def testDo(self):
  """Runs the Presto executor with per-split artifacts and checks outputs."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # One artifact per output split.
  train_examples = types.TfxArtifact(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(out_dir, 'train')
  eval_examples = types.TfxArtifact(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(out_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'custom_config':
          json_format.MessageToJson(example_gen_pb2.CustomConfig()),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ]))),
  }

  # Run executor.
  presto_example_gen = executor.Executor()
  presto_example_gen.Do({}, output_dict, exec_properties)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(train_examples.uri,
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(eval_examples.uri,
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_file))
  self.assertTrue(tf.gfile.Exists(eval_file))
  self.assertGreater(
      tf.gfile.GFile(train_file).size(), tf.gfile.GFile(eval_file).size())
def test_construct_with_output_config(self):
  """Checks one artifact per configured split is produced."""
  input_base = types.TfxArtifact(type_name='ExternalPath')
  example_gen = TestFileBasedExampleGenComponent(
      input_base=channel.as_channel([input_base]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  self.assertEqual('ExamplesPath', example_gen.outputs.examples.type_name)
  artifact_collection = example_gen.outputs.examples.get()
  for index, split in enumerate(['train', 'eval', 'test']):
    self.assertEqual(split, artifact_collection[index].split)
def testConstructWithOutputConfig(self):
  """Checks one artifact per configured split is produced."""
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input_base=channel_utils.as_channel([input_base]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  self.assertEqual('ExamplesPath', example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  for index, split in enumerate(['train', 'eval', 'test']):
    self.assertEqual(split, artifact_collection[index].split)
def test_construct_with_output_config(self):
  """Checks BigQueryExampleGen outputs with a three-way split config."""
  big_query_example_gen = component.BigQueryExampleGen(
      query='',
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  self.assertEqual('ExamplesPath',
                   big_query_example_gen.outputs.examples.type_name)
  artifact_collection = big_query_example_gen.outputs.examples.get()
  for index, split in enumerate(['train', 'eval', 'test']):
    self.assertEqual(split, artifact_collection[index].split)
def testConstructWithOutputConfig(self):
  """Checks BigQueryToElwcExampleGen yields an Examples output."""
  elwc_config = elwc_config_pb2.ElwcConfig(
      context_feature_fields=['query_id', 'query_content'])
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
      ]))
  big_query_to_elwc_example_gen = component.BigQueryToElwcExampleGen(
      query='query', elwc_config=elwc_config, output_config=output_config)
  self.assertEqual(
      standard_artifacts.Examples.TYPE_NAME,
      big_query_to_elwc_example_gen.outputs['examples'].type_name)
def testDo(self):
  """Runs the Presto executor and verifies the generated splits."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact the executor writes into.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  output_dict = {'examples': [examples]}

  exec_properties = {
      'input_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'custom_config':
          proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
      'output_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ]))),
  }

  # Run executor.
  presto_example_gen = executor.Executor()
  presto_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'Split-train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'Split-eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_file))
  self.assertTrue(fileio.exists(eval_file))
  self.assertGreater(
      fileio.open(train_file).size(), fileio.open(eval_file).size())
def testDo(self):
  """Runs the Parquet executor and verifies the generated splits."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact the executor writes into.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]}

  exec_properties = {
      standard_component_specs.INPUT_BASE_KEY:
          self._input_data_dir,
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='parquet', pattern='parquet/*'),
              ])),
      standard_component_specs.OUTPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ]))),
  }

  # Run executor.
  parquet_example_gen = parquet_executor.Executor()
  parquet_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      examples.split_names)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'Split-train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'Split-eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_file))
  self.assertTrue(fileio.exists(eval_file))
  self.assertGreater(
      fileio.open(train_file).size(), fileio.open(eval_file).size())
def testConstructWithOutputConfig(self):
  """Checks a single multi-split Examples artifact is produced."""
  big_query_example_gen = component.BigQueryExampleGen(
      query='query',
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_example_gen.outputs['examples'].type_name)
  artifact_collection = big_query_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(
      ['train', 'eval', 'test'],
      artifact_utils.decode_split_names(artifact_collection[0].split_names))
def testConstructWithOutputConfig(self):
  """Checks a single multi-split Examples artifact is produced."""
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path',
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(
      ['train', 'eval', 'test'],
      artifact_utils.decode_split_names(artifact_collection[0].split_names))
def setUp(self):
  """Prepares test data paths and input/output config JSON."""
  # Py3 zero-arg super (was super(ExecutorTest, self)); consistent with the
  # other setUp methods in this codebase.
  super().setUp()
  self._input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata', 'external')
  # Create values in exec_properties.
  self._input_config = proto_utils.proto_to_json(
      example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='tfrecord', pattern='tfrecord/*'),
      ]))
  self._output_config = proto_utils.proto_to_json(
      example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          ])))
def testConstructWithOutputConfig(self):
  """Checks the output config round-trips through exec_properties."""
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
      ]))
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path', output_config=output_config)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  # Parsing the stored JSON back must reproduce the original proto.
  round_tripped = example_gen_pb2.Output()
  json_format.Parse(example_gen.exec_properties['output_config'],
                    round_tripped)
  self.assertEqual(output_config, round_tripped)
def testDo(self):
  """Runs the Avro executor with per-split artifacts and checks outputs."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # One artifact per output split.
  train_examples = standard_artifacts.Examples(split='train')
  train_examples.uri = os.path.join(out_dir, 'train')
  eval_examples = standard_artifacts.Examples(split='eval')
  eval_examples.uri = os.path.join(out_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='avro', pattern='avro/*.avro'),
              ])),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ]))),
  }

  # Run executor.
  avro_example_gen = avro_executor.Executor()
  avro_example_gen.Do(self._input_dict, output_dict, exec_properties)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(train_examples.uri,
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(eval_examples.uri,
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_file))
  self.assertTrue(tf.gfile.Exists(eval_file))
  self.assertGreater(
      tf.gfile.GFile(train_file).size(), tf.gfile.GFile(eval_file).size())
def testDo(self):
  """Runs the Parquet executor and checks train/eval outputs exist."""
  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Output artifact with split names pre-populated.
  examples = standard_artifacts.Examples()
  examples.uri = out_dir
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  output_dict = {'examples': [examples]}

  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='parquet', pattern='parquet/*'),
              ]),
              preserving_proto_field_name=True),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1),
                  ])),
              preserving_proto_field_name=True),
  }

  # Run executor.
  parquet_example_gen = parquet_executor.Executor()
  parquet_example_gen.Do(self._input_dict, output_dict, exec_properties)

  # Train gets 2 of 3 hash buckets, so its output file should be larger.
  train_file = os.path.join(examples.uri, 'train',
                            'data_tfrecord-00000-of-00001.gz')
  eval_file = os.path.join(examples.uri, 'eval',
                           'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_file))
  self.assertTrue(tf.io.gfile.exists(eval_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_file).size(),
      tf.io.gfile.GFile(eval_file).size())
def setUp(self):
  """Creates the input data path and input/output configs for the tests."""
  # Py3 zero-arg super (was super(ExampleGenComponentWithAvroExecutorTest,
  # self)); consistent with the other setUp methods in this codebase.
  super().setUp()
  # Create input_base.
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')
  self.avro_dir_path = os.path.join(input_data_dir, 'external')
  # Create input_config.
  self.input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
  ])
  # Create output_config.
  self.output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
def testEmptyFeature(self):
  """Partitioning on a feature with no values must raise RuntimeError."""
  # Add output config to exec properties.
  self._exec_properties['output_config'] = json_format.MessageToJson(
      example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(
              splits=[
                  example_gen_pb2.SplitConfig.Split(
                      name='train', hash_buckets=2),
                  example_gen_pb2.SplitConfig.Split(
                      name='eval', hash_buckets=1)
              ],
              partition_feature_name='i')))
  # Run executor.
  example_gen = TestExampleGenExecutor()
  # assertRaisesRegexp is a deprecated Py2-era alias; use assertRaisesRegex.
  with self.assertRaisesRegex(
      RuntimeError, 'Partition feature does not contain any value.'):
    example_gen.Do({}, self._output_dict, self._exec_properties)
def testConstructWithOutputConfig(self):
  """Checks the output config round-trips through exec_properties."""
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
      ]))
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path', output_config=output_config)
  self.assertEqual(
      standard_artifacts.Examples.TYPE_NAME,
      example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)
  # Parsing the stored JSON back must reproduce the original proto.
  round_tripped = example_gen_pb2.Output()
  proto_utils.json_to_proto(
      example_gen.exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY],
      round_tripped)
  self.assertEqual(output_config, round_tripped)
def testInvalidFeatureName(self):
  """Partitioning on a nonexistent feature must raise RuntimeError."""
  # Add output config to exec properties.
  self._exec_properties[
      utils.OUTPUT_CONFIG_KEY] = json_format.MessageToJson(
          example_gen_pb2.Output(
              split_config=example_gen_pb2.SplitConfig(
                  splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ],
                  partition_feature_name='invalid')))
  # Run executor.
  example_gen = TestExampleGenExecutor()
  # assertRaisesRegexp is a deprecated Py2-era alias; use assertRaisesRegex.
  with self.assertRaisesRegex(RuntimeError,
                              'Feature name `.*` does not exist.'):
    example_gen.Do({}, self._output_dict, self._exec_properties)
def testMakeOutputSplitNames(self):
  """Output split names come from input splits, or from the split config."""
  # Multiple input splits: output names mirror the input.
  split_names = utils.generate_output_split_names(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
      ]),
      output_config=example_gen_pb2.Output())
  self.assertListEqual(['train', 'eval'], split_names)

  # Single input split: output names come from the split config.
  split_names = utils.generate_output_split_names(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='single', pattern='single/*'),
      ]),
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
          ])))
  self.assertListEqual(['train', 'eval'], split_names)
def make_default_output_config(
    input_config: Union[example_gen_pb2.Input, Dict[Text, Any]]
) -> example_gen_pb2.Output:
  """Returns default output config based on input config."""
  # Normalize the proto form to a dict so both inputs share one code path.
  if isinstance(input_config, example_gen_pb2.Input):
    input_config = json_format.MessageToDict(
        input_config, including_default_value_fields=True)
  if len(input_config['splits']) > 1:
    # Output splits mirror the input splits; no split config required.
    return example_gen_pb2.Output()
  # Single input split: fan out into 'train'/'eval' at a 2:1 ratio.
  return example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
def setUp(self):
  """Creates the input data path and input/output configs for the tests."""
  super().setUp()
  # Create input_base.
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
      'testdata')
  self.parquet_dir_path = os.path.join(input_data_dir, 'external')
  # Create input_config.
  self.input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='parquet', pattern='parquet/*.parquet'),
  ])
  # Create output_config.
  self.output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
def testConstructWithOutputConfig(self):
  """Checks PrestoExampleGen wiring with an explicit output config."""
  presto_example_gen = component.PrestoExampleGen(
      self.conn_config,
      query='query',
      output_config=example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
              example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
          ])))
  # The connection config should be carried through custom_config untouched.
  self.assertEqual(
      self.conn_config,
      self._extract_conn_config(
          presto_example_gen.exec_properties['custom_config']))
  self.assertEqual('ExamplesPath',
                   presto_example_gen.outputs['examples'].type_name)
  artifact_collection = presto_example_gen.outputs['examples'].get()
  for index, split in enumerate(['train', 'eval', 'test']):
    self.assertEqual(split, artifact_collection[index].split)
def testInvalidFloatListFeature(self):
  """Partitioning on a float_list feature must raise RuntimeError."""
  # Add output config to exec properties.
  self._exec_properties['output_config'] = json_format.MessageToJson(
      example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(
              splits=[
                  example_gen_pb2.SplitConfig.Split(
                      name='train', hash_buckets=2),
                  example_gen_pb2.SplitConfig.Split(
                      name='eval', hash_buckets=1)
              ],
              partition_feature_name='f')))
  self._exec_properties['has_empty'] = False
  # Run executor.
  example_gen = TestExampleGenExecutor()
  # assertRaisesRegexp is a deprecated Py2-era alias; use assertRaisesRegex.
  with self.assertRaisesRegex(
      RuntimeError,
      'Only `bytes_list` and `int64_list` features are supported for partition.'
  ):
    example_gen.Do({}, self._output_dict, self._exec_properties)