def testConstructSubclassQueryBased(self):
  example_gen = TestQueryBasedExampleGenComponent(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='single', pattern='query'),
      ]))
  self.assertEqual({}, example_gen.inputs.get_all())
  self.assertEqual(base_driver.BaseDriver, example_gen.driver_class)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  self.assertIsNone(example_gen.exec_properties.get('custom_config'))
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testDoInputSplit(self):
  # Create exec properties for input split.
  self._exec_properties = {
      utils.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(name='train', pattern='train/*'),
                  example_gen_pb2.Input.Split(name='eval', pattern='eval/*')
              ])),
      utils.OUTPUT_CONFIG_KEY:
          proto_utils.proto_to_json(example_gen_pb2.Output())
  }
  self._testDo()
def testDoInputSplit(self):
  # Create exec properties for input split.
  self._exec_properties = {
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(name='train', pattern='train/*'),
                  example_gen_pb2.Input.Split(name='eval', pattern='eval/*')
              ]),
              preserving_proto_field_name=True),
      utils.OUTPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Output(), preserving_proto_field_name=True)
  }
  self._testDo()
def test_construct_with_input_config(self):
  input_base = types.TfxArtifact(type_name='ExternalPath')
  example_gen = component._FileBasedExampleGen(
      input_base=channel.as_channel([input_base]),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual('ExamplesPath', example_gen.outputs.examples.type_name)
  artifact_collection = example_gen.outputs.examples.get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
  self.assertEqual('test', artifact_collection[2].split)
def testConstructWithInputConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input=channel_utils.as_channel([input_base]),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual('ExamplesPath', example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
  self.assertEqual('test', artifact_collection[2].split)
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {utils.EXAMPLES_KEY: [examples]}

  # Create exec properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_data_dir,
      utils.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='avro', pattern='avro/*.avro'),
              ])),
      utils.OUTPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  avro_example_gen = avro_executor.Executor()
  avro_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                   examples.split_names)

  # Check Avro example gen outputs.
  train_output_file = os.path.join(examples.uri, 'train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_output_file))
  self.assertTrue(fileio.exists(eval_output_file))
  self.assertGreater(
      fileio.open(train_output_file).size(),
      fileio.open(eval_output_file).size())
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text,
                     serving_model_dir: Text) -> pipeline.Pipeline:
  examples = external_input(data_root)
  input_split = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='iris_training.csv'),
      example_gen_pb2.Input.Split(name='eval', pattern='iris_test.csv')
  ])
  example_gen = CsvExampleGen(input_base=examples, input_config=input_split)
  statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
  infer_schema = SchemaGen(stats=statistics_gen.outputs.output)
  validate_stats = ExampleValidator(
      stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)
  transform = Transform(
      input_data=example_gen.outputs.examples,
      schema=infer_schema.outputs.output,
      module_file=module_file)
  trainer = Trainer(
      module_file=module_file,
      examples=transform.outputs.transformed_examples,
      schema=infer_schema.outputs.output,
      transform_output=transform.outputs.transform_output,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500))
  model_analyzer = Evaluator(
      examples=example_gen.outputs.examples,
      model_exports=trainer.outputs.output,
      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(
          specs=[evaluator_pb2.SingleSlicingSpec()]))
  model_validator = ModelValidator(
      examples=example_gen.outputs.examples, model=trainer.outputs.output)
  pusher = Pusher(
      model_export=trainer.outputs.output,
      model_blessing=model_validator.outputs.blessing,
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      log_root='/var/tmp/tfx/logs',
  )
def testComponentspecBasic(self):
  proto = example_gen_pb2.Input()
  proto.splits.extend([
      example_gen_pb2.Input.Split(name='name1', pattern='pattern1'),
      example_gen_pb2.Input.Split(name='name2', pattern='pattern2'),
      example_gen_pb2.Input.Split(name='name3', pattern='pattern3'),
  ])
  input_channel = Channel(type_name='InputType')
  output_channel = Channel(type_name='OutputType')
  spec = _BasicComponentSpec(
      folds=10, proto=proto, input=input_channel, output=output_channel)

  # Verify proto property.
  self.assertIsInstance(spec.exec_properties['proto'], str)
  decoded_proto = json.loads(spec.exec_properties['proto'])
  self.assertCountEqual(['splits'], decoded_proto.keys())
  self.assertEqual(3, len(decoded_proto['splits']))
  self.assertCountEqual(['name1', 'name2', 'name3'],
                        list(s['name'] for s in decoded_proto['splits']))
  self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                        list(s['pattern'] for s in decoded_proto['splits']))

  # Verify other properties.
  self.assertEqual(10, spec.exec_properties['folds'])
  self.assertIs(spec.inputs.input, input_channel)
  self.assertIs(spec.outputs.output, output_channel)

  with self.assertRaisesRegexp(
      TypeError,
      "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
      'string.'):
    spec = _BasicComponentSpec(
        folds='string', input=input_channel, output=output_channel)

  with self.assertRaisesRegexp(TypeError,
                               'Expected InputType but found WrongType'):
    spec = _BasicComponentSpec(
        folds=10, input=Channel(type_name='WrongType'), output=output_channel)

  with self.assertRaisesRegexp(TypeError,
                               'Expected OutputType but found WrongType'):
    spec = _BasicComponentSpec(
        folds=10, input=input_channel, output=Channel(type_name='WrongType'))
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = standard_artifacts.Examples(split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = standard_artifacts.Examples(split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ]),
              preserving_proto_field_name=True),
      'custom_config':
          json_format.MessageToJson(example_gen_pb2.CustomConfig()),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ]))),
  }

  # Run executor.
  presto_example_gen = executor.Executor()
  presto_example_gen.Do({}, output_dict, exec_properties)

  # Check Presto example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_output_file))
  self.assertTrue(tf.io.gfile.exists(eval_output_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_output_file).size(),
      tf.io.gfile.GFile(eval_output_file).size())
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {utils.EXAMPLES_KEY: [examples]}

  # Create exec properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_data_dir,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='parquet', pattern='parquet/*'),
              ]),
              preserving_proto_field_name=True),
      utils.OUTPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])),
              preserving_proto_field_name=True)
  }

  # Run executor.
  parquet_example_gen = parquet_executor.Executor()
  parquet_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                   examples.split_names)

  # Check Parquet example gen outputs.
  train_output_file = os.path.join(examples.uri, 'train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_output_file))
  self.assertTrue(tf.io.gfile.exists(eval_output_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_output_file).size(),
      tf.io.gfile.GFile(eval_output_file).size())
def testDo(self, mock_client):
  # Mock query result schema for _BigQueryConverter.
  mock_client.return_value.query.return_value.result.return_value.schema = self._schema

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = standard_artifacts.Examples(split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = standard_artifacts.Examples(split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  big_query_example_gen = executor.Executor()
  big_query_example_gen.Do({}, output_dict, exec_properties)

  # Check BigQuery example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def testComponentspecBasic(self):
  proto = example_gen_pb2.Input()
  proto.splits.extend([
      example_gen_pb2.Input.Split(name='name1', pattern='pattern1'),
      example_gen_pb2.Input.Split(name='name2', pattern='pattern2'),
      example_gen_pb2.Input.Split(name='name3', pattern='pattern3'),
  ])
  input_channel = Channel(type=_InputArtifact)
  output_channel = Channel(type=_OutputArtifact)
  spec = _BasicComponentSpec(
      folds=10, proto=proto, input=input_channel, output=output_channel)

  # Verify proto property.
  self.assertIsInstance(spec.exec_properties['proto'], str)
  decoded_proto = json.loads(spec.exec_properties['proto'])
  self.assertCountEqual(['splits'], decoded_proto.keys())
  self.assertEqual(3, len(decoded_proto['splits']))
  self.assertCountEqual(['name1', 'name2', 'name3'],
                        list(s['name'] for s in decoded_proto['splits']))
  self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                        list(s['pattern'] for s in decoded_proto['splits']))

  # Verify other properties.
  self.assertEqual(10, spec.exec_properties['folds'])
  self.assertIs(spec.inputs['input'], input_channel)
  self.assertIs(spec.outputs['output'], output_channel)

  # Verify compatibility aliasing behavior.
  self.assertIs(spec.inputs['future_input_name'], spec.inputs['input'])
  self.assertIs(spec.outputs['future_output_name'], spec.outputs['output'])

  with self.assertRaisesRegexp(
      TypeError,
      "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
      'string.'):
    spec = _BasicComponentSpec(
        folds='string', input=input_channel, output=output_channel)

  with self.assertRaisesRegexp(
      TypeError,
      '.*should be a Channel of .*InputArtifact.*got (.|\\s)*Examples.*'):
    spec = _BasicComponentSpec(
        folds=10, input=Channel(type=Examples), output=output_channel)

  with self.assertRaisesRegexp(
      TypeError,
      '.*should be a Channel of .*OutputArtifact.*got (.|\\s)*Examples.*'):
    spec = _BasicComponentSpec(
        folds=10, input=input_channel, output=Channel(type=Examples))
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  examples = standard_artifacts.Examples()
  examples.uri = output_data_dir
  output_dict = {'examples': [examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='bq', pattern='SELECT i, f, s FROM `fake`'),
              ])),
      'custom_config':
          proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
      'output_config':
          proto_utils.proto_to_json(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ]))),
  }

  # Run executor.
  presto_example_gen = executor.Executor()
  presto_example_gen.Do({}, output_dict, exec_properties)

  self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                   examples.split_names)

  # Check Presto example gen outputs.
  train_output_file = os.path.join(examples.uri, 'Split-train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(fileio.exists(train_output_file))
  self.assertTrue(fileio.exists(eval_output_file))
  self.assertGreater(
      fileio.open(train_output_file).size(),
      fileio.open(eval_output_file).size())
def testConstructWithInputConfig(self):
  input_base = standard_artifacts.ExternalArtifact()
  example_gen = TestFileBasedExampleGenComponent(
      input=channel_utils.as_channel([input_base]),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='train/*'),
          example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
          example_gen_pb2.Input.Split(name='test', pattern='test/*')
      ]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)
  artifact_collection = example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testConstructWithInputConfig(self):
  big_query_elwc_example_gen = component.BigQueryElwcExampleGen(
      elwc_config=example_gen_pb2.ElwcConfig(
          context_feature_fields=['query_id', 'query_content']),
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='query1'),
          example_gen_pb2.Input.Split(name='eval', pattern='query2'),
          example_gen_pb2.Input.Split(name='test', pattern='query3')
      ]))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   big_query_elwc_example_gen.outputs['examples'].type_name)
  artifact_collection = big_query_elwc_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def _make_example_gen(self) -> base_component.BaseComponent:
  """Returns a TFX ExampleGen which produces the desired split."""
  splits = []
  for name, value in self._dataset_builder.info.splits.items():
    # Assume there is only one file per split.
    # Filename will be like `'fashion_mnist-test.tfrecord-00000-of-00001'`.
    assert len(value.filenames) == 1
    pattern = value.filenames[0]
    splits.append(example_gen_pb2.Input.Split(name=name, pattern=pattern))
  logging.info('Splits: %s', splits)
  input_config = example_gen_pb2.Input(splits=splits)
  return tfx.ImportExampleGen(
      input=external_input(self._dataset_builder.data_dir),
      input_config=input_config)
def resolve_exec_properties( self, exec_properties: Dict[Text, Any], pipeline_info: data_types.PipelineInfo, component_info: data_types.ComponentInfo, ) -> Dict[Text, Any]: """Overrides BaseDriver.resolve_exec_properties().""" del pipeline_info, component_info input_config = example_gen_pb2.Input() proto_utils.json_to_proto( exec_properties[standard_component_specs.INPUT_CONFIG_KEY], input_config) input_base = exec_properties[standard_component_specs.INPUT_BASE_KEY] logging.debug('Processing input %s.', input_base) range_config = None range_config_entry = exec_properties.get( standard_component_specs.RANGE_CONFIG_KEY) if range_config_entry: range_config = range_config_pb2.RangeConfig() proto_utils.json_to_proto(range_config_entry, range_config) if range_config.HasField('static_range'): # For ExampleGen, StaticRange must specify an exact span to look for, # since only one span is processed at a time. start_span_number = range_config.static_range.start_span_number end_span_number = range_config.static_range.end_span_number if start_span_number != end_span_number: raise ValueError( 'Start and end span numbers for RangeConfig.static_range must ' 'be equal: (%s, %s)' % (start_span_number, end_span_number)) # Note that this function updates the input_config.splits.pattern. fingerprint, span, version = utils.calculate_splits_fingerprint_span_and_version( input_base, input_config.splits, range_config) exec_properties[standard_component_specs. INPUT_CONFIG_KEY] = proto_utils.proto_to_json( input_config) exec_properties[utils.SPAN_PROPERTY_NAME] = span exec_properties[utils.VERSION_PROPERTY_NAME] = version exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint return exec_properties
def testDo(self):
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create output dict.
  train_examples = types.TfxArtifact(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_data_dir, 'train')
  eval_examples = types.TfxArtifact(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_data_dir, 'eval')
  output_dict = {'examples': [train_examples, eval_examples]}

  # Create exec properties.
  exec_properties = {
      'input_config':
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='parquet', pattern='parquet/*'),
              ])),
      'output_config':
          json_format.MessageToJson(
              example_gen_pb2.Output(
                  split_config=example_gen_pb2.SplitConfig(splits=[
                      example_gen_pb2.SplitConfig.Split(
                          name='train', hash_buckets=2),
                      example_gen_pb2.SplitConfig.Split(
                          name='eval', hash_buckets=1)
                  ])))
  }

  # Run executor.
  parquet_example_gen = parquet_executor.Executor()
  parquet_example_gen.Do(self._input_dict, output_dict, exec_properties)

  # Check Parquet example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def testConstructSubclassQueryBased(self):
  example_gen = TestQueryBasedExampleGenComponent(
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='single', pattern='query'),
      ]))
  self.assertEqual({}, example_gen.inputs.get_all())
  self.assertEqual(driver.QueryBasedDriver, example_gen.driver_class)
  self.assertEqual(
      standard_artifacts.Examples.TYPE_NAME,
      example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)
  self.assertEqual(
      example_gen.exec_properties[
          standard_component_specs.OUTPUT_DATA_FORMAT_KEY],
      example_gen_pb2.FORMAT_TF_EXAMPLE)
  self.assertIsNone(
      example_gen.exec_properties.get(
          standard_component_specs.CUSTOM_CONFIG_KEY))
def setUp(self):
  super(ExecutorTest, self).setUp()
  self._input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'testdata',
      'external')

  # Create values in exec_properties.
  self._input_config = proto_utils.proto_to_json(
      example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='tfrecord', pattern='tfrecord/*'),
      ]))
  self._output_config = proto_utils.proto_to_json(
      example_gen_pb2.Output(
          split_config=example_gen_pb2.SplitConfig(splits=[
              example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
              example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
          ])))
def testExecutionParameterTypeCheck(self):
  int_parameter = ExecutionParameter(type=int)
  int_parameter.type_check('int_parameter', 8)
  with self.assertRaisesRegex(
      TypeError, "Expected type <(class|type) 'int'>"
      " for parameter u?'int_parameter'"):
    int_parameter.type_check('int_parameter', 'string')

  list_parameter = ExecutionParameter(type=List[int])
  list_parameter.type_check('list_parameter', [])
  list_parameter.type_check('list_parameter', [42])
  with self.assertRaisesRegex(TypeError, 'Expecting a list for parameter'):
    list_parameter.type_check('list_parameter', 42)

  with self.assertRaisesRegex(
      TypeError, "Expecting item type <(class|type) "
      "'int'> for parameter u?'list_parameter'"):
    list_parameter.type_check('list_parameter', [42, 'wrong item'])

  dict_parameter = ExecutionParameter(type=Dict[str, int])
  dict_parameter.type_check('dict_parameter', {})
  dict_parameter.type_check('dict_parameter', {'key1': 1, 'key2': 2})
  with self.assertRaisesRegex(TypeError, 'Expecting a dict for parameter'):
    dict_parameter.type_check('dict_parameter', 'simple string')

  with self.assertRaisesRegex(
      TypeError, "Expecting value type "
      "<(class|type) 'int'>"):
    dict_parameter.type_check('dict_parameter', {'key1': '1'})

  proto_parameter = ExecutionParameter(type=example_gen_pb2.Input)
  proto_parameter.type_check('proto_parameter', example_gen_pb2.Input())
  proto_parameter.type_check('proto_parameter',
                             {'splits': [{'name': 'hello'}]})
  proto_parameter.type_check('proto_parameter', {'wrong_field': 42})
  with self.assertRaisesRegex(
      TypeError, "Expected type <class 'tfx.proto.example_gen_pb2.Input'>"):
    proto_parameter.type_check('proto_parameter', 42)
  with self.assertRaises(json_format.ParseError):
    proto_parameter.type_check('proto_parameter', {'splits': 42})
def resolve_exec_properties( self, exec_properties: Dict[Text, Any], pipeline_info: data_types.PipelineInfo, component_info: data_types.ComponentInfo, ) -> Dict[Text, Any]: """Overrides BaseDriver.resolve_exec_properties().""" del pipeline_info, component_info input_config = example_gen_pb2.Input() proto_utils.json_to_proto( exec_properties[standard_component_specs.INPUT_CONFIG_KEY], input_config) input_base = exec_properties.get( standard_component_specs.INPUT_BASE_KEY) logging.debug('Processing input %s.', input_base) range_config = None range_config_entry = exec_properties.get( standard_component_specs.RANGE_CONFIG_KEY) if range_config_entry: range_config = range_config_pb2.RangeConfig() proto_utils.json_to_proto(range_config_entry, range_config) processor = self.get_input_processor(splits=input_config.splits, range_config=range_config, input_base_uri=input_base) span, version = processor.resolve_span_and_version() fingerprint = processor.get_input_fingerprint(span, version) # Updates the input_config.splits.pattern. for split in input_config.splits: split.pattern = processor.get_pattern_for_span_version( split.pattern, span, version) exec_properties[standard_component_specs. INPUT_CONFIG_KEY] = proto_utils.proto_to_json( input_config) exec_properties[utils.SPAN_PROPERTY_NAME] = span exec_properties[utils.VERSION_PROPERTY_NAME] = version exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint return exec_properties
def testConstructWithInputConfig(self):
  presto_example_gen = component.PrestoExampleGen(
      self.conn_config,
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='query1'),
          example_gen_pb2.Input.Split(name='eval', pattern='query2'),
          example_gen_pb2.Input.Split(name='test', pattern='query3')
      ]))
  self.assertEqual(
      self.conn_config,
      self._extract_conn_config(
          presto_example_gen.exec_properties['custom_config']))
  self.assertEqual('ExamplesPath',
                   presto_example_gen.outputs['examples'].type_name)
  artifact_collection = presto_example_gen.outputs['examples'].get()
  self.assertEqual('train', artifact_collection[0].split)
  self.assertEqual('eval', artifact_collection[1].split)
  self.assertEqual('test', artifact_collection[2].split)
def setUp(self):
  super(ExampleGenComponentWithAvroExecutorTest, self).setUp()
  # Create input_base.
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'testdata')
  self.avro_dir_path = os.path.join(input_data_dir, 'external')

  # Create input_config.
  self.input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
  ])

  # Create output_config.
  self.output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
      ]))
def testResolveInputArtifactsWithSpan(self):
  # Test align of span number.
  span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                              'data')
  io_utils.write_string_file(span1_split1, 'testing11')
  span1_split2 = os.path.join(self._input_base_path, 'span01', 'split2',
                              'data')
  io_utils.write_string_file(span1_split2, 'testing12')
  span2_split1 = os.path.join(self._input_base_path, 'span02', 'split1',
                              'data')
  io_utils.write_string_file(span2_split1, 'testing21')

  with self.assertRaisesRegexp(
      ValueError, 'Latest span should be the same for each split'):
    self._example_gen_driver.resolve_input_artifacts(self._input_channels,
                                                     self._exec_properties,
                                                     None, None)

  # Test if latest span is selected when span aligns for each split.
  span2_split2 = os.path.join(self._input_base_path, 'span02', 'split2',
                              'data')
  io_utils.write_string_file(span2_split2, 'testing22')

  self._mock_metadata.get_artifacts_by_uri.return_value = []
  self._mock_metadata.publish_artifacts.return_value = [
      metadata_store_pb2.Artifact()
  ]
  self._example_gen_driver.resolve_input_artifacts(self._input_channels,
                                                   self._exec_properties, None,
                                                   None)

  updated_input_config = example_gen_pb2.Input()
  json_format.Parse(self._exec_properties['input_config'],
                    updated_input_config)

  # Check if latest span is selected.
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span02/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span02/split2/*"
      }""", updated_input_config)
def testDriverWithSpan(self):
  # Test align of span number.
  span1_split1 = os.path.join(_TEST_INPUT_DIR, 'span01', 'split1', 'data')
  io_utils.write_string_file(span1_split1, 'testing11')
  span1_split2 = os.path.join(_TEST_INPUT_DIR, 'span01', 'split2', 'data')
  io_utils.write_string_file(span1_split2, 'testing12')
  span2_split1 = os.path.join(_TEST_INPUT_DIR, 'span02', 'split1', 'data')
  io_utils.write_string_file(span2_split1, 'testing21')

  serialized_args = [
      'driver.py', '--json_serialized_invocation_args',
      json_format.MessageToJson(message=self._executor_invocation)
  ]
  with self.assertRaisesRegexp(
      ValueError, 'Latest span should be the same for each split'):
    driver.main(serialized_args)

  # Test if latest span is selected when span aligns for each split.
  span2_split2 = os.path.join(_TEST_INPUT_DIR, 'span02', 'split2', 'data')
  io_utils.write_string_file(span2_split2, 'testing22')
  driver.main(serialized_args)

  # Check the output metadata file for the expected outputs.
  with open(_TEST_OUTPUT_METADATA_JSON) as output_meta_json:
    output_metadata = pipeline_pb2.ExecutorOutput()
    json_format.Parse(
        output_meta_json.read(), output_metadata, ignore_unknown_fields=True)
    self.assertEqual(output_metadata.parameters['span'].string_value, '2')
    self.assertEqual(
        output_metadata.parameters['input_config'].string_value,
        json_format.MessageToJson(
            example_gen_pb2.Input(splits=[
                example_gen_pb2.Input.Split(
                    name='s1', pattern='span02/split1/*'),
                example_gen_pb2.Input.Split(
                    name='s2', pattern='span02/split2/*')
            ])))
def setUp(self):
  super().setUp()
  # Create input_base.
  input_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'testdata')
  self.parquet_dir_path = os.path.join(input_data_dir, 'external')

  # Create input_config.
  self.input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='parquet', pattern='parquet/*.parquet'),
  ])

  # Create output_config.
  self.output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
      ]))
def testConstructWithInputConfig(self):
  presto_example_gen = component.PrestoExampleGen(
      self.conn_config,
      input_config=example_gen_pb2.Input(splits=[
          example_gen_pb2.Input.Split(name='train', pattern='query1'),
          example_gen_pb2.Input.Split(name='eval', pattern='query2'),
          example_gen_pb2.Input.Split(name='test', pattern='query3')
      ]))
  self.assertEqual(
      self.conn_config,
      self._extract_conn_config(
          presto_example_gen.exec_properties['custom_config']))
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   presto_example_gen.outputs['examples'].type_name)
  artifact_collection = presto_example_gen.outputs['examples'].get()
  self.assertEqual(1, len(artifact_collection))
  self.assertEqual(['train', 'eval', 'test'],
                   artifact_utils.decode_split_names(
                       artifact_collection[0].split_names))
def testBuildFileBasedExampleGenWithInputConfig(self):
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='*train.tfr'),
      example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
  ])
  example_gen = components.ImportExampleGen(
      input_base='path/to/data/root', input_config=input_config)
  deployment_config = pipeline_pb2.PipelineDeploymentConfig()

  my_builder = step_builder.StepBuilder(
      node=example_gen,
      image='gcr.io/tensorflow/tfx:latest',
      deployment_config=deployment_config)
  actual_step_spec = self._sole(my_builder.build())

  self.assertProtoEquals(
      test_utils.get_proto_from_test_data('expected_import_example_gen.pbtxt',
                                          pipeline_pb2.PipelineTaskSpec()),
      actual_step_spec)
  self.assertProtoEquals(
      test_utils.get_proto_from_test_data(
          'expected_import_example_gen_executor.pbtxt',
          pipeline_pb2.PipelineDeploymentConfig()), deployment_config)
def testPrestoToExample(self):
  with beam.Pipeline() as pipeline:
    examples = (
        pipeline | 'ToTFExample' >> executor._PrestoToExample(
            exec_properties={
                'input_config':
                    proto_utils.proto_to_json(example_gen_pb2.Input()),
                'custom_config':
                    proto_utils.proto_to_json(example_gen_pb2.CustomConfig())
            },
            split_pattern='SELECT i, f, s FROM `fake`'))

    feature = {}
    feature['i'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
    feature['f'] = tf.train.Feature(float_list=tf.train.FloatList(value=[2.0]))
    feature['s'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes('abc')]))
    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature))
    util.assert_that(examples, util.equal_to([example_proto]))