def testRangeConfigSpanWidthPresence(self):
  """Tests RangeConfig.static_range behavior when span width is not given.

  A `{SPAN}` spec without a width modifier cannot match the zero-padded
  directory `span01`, while `{SPAN:2}` can.
  """
  span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                              'data')
  io_utils.write_string_file(span1_split1, 'testing11')

  range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=1))

  splits1 = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*')
  ]
  # RangeConfig cannot find zero padding span without width modifier.
  with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, splits1, range_config=range_config)

  splits2 = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
  ]
  # With width modifier in span spec, RangeConfig.static_range makes
  # correct zero-padded substitution.
  _, span, version = utils.calculate_splits_fingerprint_span_and_version(
      self._input_base_path, splits2, range_config=range_config)
  self.assertEqual(span, 1)
  self.assertIsNone(version)
  # The split's pattern is rewritten with the resolved, zero-padded span.
  self.assertEqual(splits2[0].pattern, 'span01/split1/*')
def testFileBasedInputProcessor(self):
  """Tests spec validation in the FileBasedInputProcessor constructor."""
  # TODO(b/181275944): migrate test after refactoring FileBasedInputProcessor.
  # 's1' carries a {SPAN} spec but 's2' does not: inconsistent spec setup.
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s1', pattern='path/{SPAN}'),
      example_gen_pb2.Input.Split(name='s2', pattern='path2')
  ])
  # No span or date spec at all.
  input_config2 = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s', pattern='path'),
  ])
  static_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))

  with self.assertRaisesRegex(ValueError,
                              'Spec setup should the same for all splits'):
    input_processor.FileBasedInputProcessor('input_base_uri',
                                            input_config.splits, None)

  # Expected to fail: a RangeConfig is given but no split has a span/date spec.
  with self.assertRaisesRegex(ValueError,
                              'Span or Date spec should be specified'):
    input_processor.FileBasedInputProcessor('input_base_uri',
                                            input_config2.splits,
                                            static_range_config)
def testInputProcessor(self):
  """Tests that unsupported RangeConfigs are rejected by the input processor."""
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s', pattern='path/{SPAN}'),
  ])
  # Static range spanning two spans (start != end).
  static_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=2))
  # Rolling range covering more than one span.
  rolling_range_config = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(num_spans=2))
  # Rolling range that sets start_span_number.
  rolling_range_config2 = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(
          num_spans=1, start_span_number=1))

  with self.assertRaisesRegex(
      ValueError,
      'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
  ):
    TestInputProcessor(input_config.splits, static_range_config)

  with self.assertRaisesRegex(
      ValueError,
      'ExampleGen only support single span for RangeConfig.RollingRange'):
    TestInputProcessor(input_config.splits, rolling_range_config)

  with self.assertRaisesRegex(
      ValueError,
      'RangeConfig.rolling_range.start_span_number is not supported'):
    TestInputProcessor(input_config.splits, rolling_range_config2)
def testConstructSubclassQueryBasedWithRangeConfig(self):
  """Checks a query-based ExampleGen subclass constructed with a RangeConfig.

  `@span_yyyymmdd_utc` will be replaced with '19700103' in the query, and
  span `2` will be recorded in the output Examples artifact.
  """
  expected_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  single_split = example_gen_pb2.Input.Split(
      name='single',
      pattern='select * from table where date=@span_yyyymmdd_utc')

  gen = TestQueryBasedExampleGenComponent(
      input_config=example_gen_pb2.Input(splits=[single_split]),
      range_config=expected_range_config)

  # The component is expected to declare no inputs.
  self.assertEqual({}, gen.inputs)
  self.assertEqual(driver.QueryBasedDriver, gen.driver_class)
  examples_channel = gen.outputs[standard_component_specs.EXAMPLES_KEY]
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   examples_channel.type_name)

  # exec_properties holds the RangeConfig as JSON; parse it back to compare.
  restored = range_config_pb2.RangeConfig()
  serialized = gen.exec_properties[standard_component_specs.RANGE_CONFIG_KEY]
  proto_utils.json_to_proto(serialized, restored)
  self.assertEqual(expected_range_config, restored)
def testQueryBasedDriver(self):
  """Tests that QueryBasedDriver substitutes the static span into queries."""
  # Create exec properties.
  query_splits = [
      example_gen_pb2.Input.Split(
          name='s1',
          pattern="select * from table where span={SPAN} and split='s1'"),
      example_gen_pb2.Input.Split(
          name='s2',
          pattern="select * from table where span={SPAN} and split='s2'"),
  ]
  only_span_two = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  exec_properties = {
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=query_splits)),
      standard_component_specs.RANGE_CONFIG_KEY:
          proto_utils.proto_to_json(only_span_two),
  }

  # Prepare output_dict.
  example = standard_artifacts.Examples()
  example.uri = 'my_uri'
  output_dict = {standard_component_specs.EXAMPLES_KEY: [example]}

  query_driver = driver.QueryBasedDriver(self._mock_metadata)
  execution_info = portable_data_types.ExecutionInfo(
      output_dict=output_dict, exec_properties=exec_properties)
  result = query_driver.run(execution_info)

  # The driver records span/version/fingerprint back into exec_properties.
  self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME], 2)
  self.assertIsNone(exec_properties[utils.VERSION_PROPERTY_NAME])
  self.assertIsNone(exec_properties[utils.FINGERPRINT_PROPERTY_NAME])

  resolved_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      resolved_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "select * from table where span=2 and split='s1'"
      }
      splits {
        name: "s2"
        pattern: "select * from table where span=2 and split='s2'"
      }""", resolved_input_config)

  produced = result.output_artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts
  self.assertLen(produced, 1)
  output_example = produced[0]
  self.assertEqual(output_example.uri, example.uri)
  span_property = output_example.custom_properties[utils.SPAN_PROPERTY_NAME]
  self.assertEqual(span_property.string_value, '2')
def testConstructWithStaticRangeConfig(self):
  """Checks that FileBasedExampleGen keeps a static RangeConfig as given."""
  single_span = range_config_pb2.StaticRange(
      start_span_number=1, end_span_number=1)
  expected = range_config_pb2.RangeConfig(static_range=single_span)

  gen = component.FileBasedExampleGen(
      input_base='path',
      range_config=expected,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))

  # The RangeConfig is stored as JSON; parse it back for comparison.
  actual = range_config_pb2.RangeConfig()
  json_format.Parse(gen.exec_properties['range_config'], actual)
  self.assertEqual(expected, actual)
def testInputProcessor(self):
  """Tests spec and RangeConfig validation for query-based input patterns."""
  # 's1' has a {SPAN} spec but 's2' does not: inconsistent spec setup.
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='s1',
          pattern="select * from table where span={SPAN} and split='s1'"),
      example_gen_pb2.Input.Split(
          name='s2', pattern="select * from table where and split='s2'")
  ])
  # No span or date spec at all.
  input_config2 = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s', pattern='select * from table'),
  ])
  # A single split with a valid {SPAN} spec.
  input_config3 = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='s', pattern='select * from table where span={SPAN}'),
  ])
  static_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=2))
  rolling_range_config = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(num_spans=2))
  rolling_range_config2 = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(
          num_spans=1, start_span_number=1))

  with self.assertRaisesRegex(ValueError,
                              'Spec setup should the same for all splits'):
    TestInputProcessor(input_config.splits, None)

  with self.assertRaisesRegex(ValueError,
                              'Span or Date spec should be specified'):
    TestInputProcessor(input_config2.splits, static_range_config)

  with self.assertRaisesRegex(
      ValueError,
      'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
  ):
    TestInputProcessor(input_config3.splits, static_range_config)

  with self.assertRaisesRegex(
      ValueError,
      'ExampleGen only support single span for RangeConfig.RollingRange'):
    TestInputProcessor(input_config3.splits, rolling_range_config)

  with self.assertRaisesRegex(
      ValueError,
      'RangeConfig.rolling_range.start_span_number is not supported'):
    TestInputProcessor(input_config3.splits, rolling_range_config2)
def testConstructWithStaticRangeConfig(self):
  """Checks that FileBasedExampleGen serializes a static RangeConfig."""
  expected = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=1))
  beam_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
  gen = component.FileBasedExampleGen(
      input_base='path', range_config=expected, custom_executor_spec=beam_spec)

  # Round-trip the JSON held in exec_properties back into a proto.
  serialized = gen.exec_properties[standard_component_specs.RANGE_CONFIG_KEY]
  actual = range_config_pb2.RangeConfig()
  proto_utils.json_to_proto(serialized, actual)
  self.assertEqual(expected, actual)
def testRangeConfigWithNonexistentSpan(self):
  """Tests that a static range naming a span absent on disk raises."""
  # Only span 1 is written; the RangeConfig below asks for span 2.
  existing_file = os.path.join(self._input_base_path, 'span01', 'split1',
                               'data')
  io_utils.write_string_file(existing_file, 'testing11')

  missing_span = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  padded_splits = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
  ]
  with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
    utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, padded_splits, range_config=missing_span)
def testResolveArtifacts(self):
  """Tests SpansResolver.resolve_artifacts with static and rolling ranges."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    candidates = [self._createExamples(span) for span in range(1, 6)]
    _, artifact2, artifact3, artifact4, artifact5 = candidates

    def resolved_uris(range_config):
      # Resolves a fresh copy of all five candidates; returns result URIs.
      resolver = spans_resolver.SpansResolver(range_config=range_config)
      result = resolver.resolve_artifacts(m, {'input': list(candidates)})
      self.assertIsNotNone(result)
      return [a.uri for a in result['input']]

    # Test StaticRange: spans 2 through 3, ordering not checked.
    uris = resolved_uris(
        range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=2, end_span_number=3)))
    self.assertEqual(set(uris), {artifact3.uri, artifact2.uri})

    # Test RollingRange: the three latest spans, newest first.
    uris = resolved_uris(
        range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=3)))
    self.assertEqual(uris, [artifact5.uri, artifact4.uri, artifact3.uri])

    # Test RollingRange with start_span_number: spans below 4 are excluded.
    uris = resolved_uris(
        range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(
                start_span_number=4, num_spans=3)))
    self.assertEqual(uris, [artifact5.uri, artifact4.uri])
def testQueryBasedInputProcessor(self):
  """Tests span resolution and pattern substitution for query-based input."""
  # A query without any span spec.
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s', pattern='select * from table'),
  ])
  # Queries carrying the @span_yyyymmdd_utc date spec.
  input_config_span = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(
          name='s1',
          pattern='select * from table where date=@span_yyyymmdd_utc'),
      example_gen_pb2.Input.Split(
          name='s2',
          pattern='select * from table2 where date=@span_yyyymmdd_utc')
  ])
  static_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  rolling_range_config = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(num_spans=1))

  # A rolling range would require looking up the latest span, which the
  # query-based processor rejects.
  with self.assertRaisesRegex(
      NotImplementedError,
      'For QueryBasedExampleGen, latest Span is not supported'):
    processor = input_processor.QueryBasedInputProcessor(
        input_config_span.splits, rolling_range_config)
    processor.resolve_span_and_version()

  # Without a span spec or RangeConfig: span 0, no version, no fingerprint.
  processor = input_processor.QueryBasedInputProcessor(
      input_config.splits, None)
  span, version = processor.resolve_span_and_version()
  fp = processor.get_input_fingerprint(span, version)
  self.assertEqual(span, 0)
  self.assertIsNone(version)
  self.assertIsNone(fp)

  # Static range selects span 2, which is rendered as the date '19700103'.
  processor = input_processor.QueryBasedInputProcessor(
      input_config_span.splits, static_range_config)
  span, version = processor.resolve_span_and_version()
  fp = processor.get_input_fingerprint(span, version)
  self.assertEqual(span, 2)
  self.assertIsNone(version)
  self.assertIsNone(fp)
  pattern = processor.get_pattern_for_span_version(
      input_config_span.splits[0].pattern, span, version)
  self.assertEqual(pattern, "select * from table where date='19700103'")
def testConstructWithRangeConfig(self):
  """Checks that BigQueryExampleGen stores the RangeConfig it is given.

  `@span_yyyymmdd_utc` will be replaced with '19700103' in the query, and
  span `2` will be recorded in the output Examples artifact.
  """
  expected_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  bq_gen = component.BigQueryExampleGen(
      query='select * from table where date=@span_yyyymmdd_utc',
      range_config=expected_range_config)

  examples_output = bq_gen.outputs[standard_component_specs.EXAMPLES_KEY]
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   examples_output.type_name)

  round_tripped = range_config_pb2.RangeConfig()
  proto_utils.json_to_proto(
      bq_gen.exec_properties[standard_component_specs.RANGE_CONFIG_KEY],
      round_tripped)
  self.assertEqual(expected_range_config, round_tripped)
def testStrategy_IrMode(self):
  """Tests SpanRangeStrategy (IR mode) with static and rolling ranges."""
  candidates = [self._createExamples(span) for span in range(1, 6)]
  _, artifact2, artifact3, artifact4, artifact5 = candidates

  def resolved_uris(range_config):
    # Runs the strategy over a fresh copy of the candidates; returns URIs.
    strategy = span_range_strategy.SpanRangeStrategy(range_config=range_config)
    result = strategy.resolve_artifacts(self._store,
                                        {'input': list(candidates)})
    self.assertIsNotNone(result)
    return [a.uri for a in result['input']]

  # Test StaticRange: spans 2 through 3, ordering not checked.
  uris = resolved_uris(
      range_config_pb2.RangeConfig(
          static_range=range_config_pb2.StaticRange(
              start_span_number=2, end_span_number=3)))
  self.assertEqual(set(uris), {artifact3.uri, artifact2.uri})

  # Test RollingRange: the three latest spans, newest first.
  uris = resolved_uris(
      range_config_pb2.RangeConfig(
          rolling_range=range_config_pb2.RollingRange(num_spans=3)))
  self.assertEqual(uris, [artifact5.uri, artifact4.uri, artifact3.uri])

  # Test RollingRange with start_span_number: spans below 4 are excluded.
  uris = resolved_uris(
      range_config_pb2.RangeConfig(
          rolling_range=range_config_pb2.RollingRange(
              start_span_number=4, num_spans=3)))
  self.assertEqual(uris, [artifact5.uri, artifact4.uri])
def testResolve(self):
  """Tests that SpansResolver.resolve returns only the span-1 artifact."""
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)

    def new_examples(uri, span):
      # Builds an Examples artifact tagged with the given span string.
      examples = standard_artifacts.Examples()
      examples.uri = uri
      examples.set_string_custom_property(utils.SPAN_PROPERTY_NAME, span)
      return examples

    artifact_one = new_examples('uri_one', '1')
    m.publish_artifacts([artifact_one])
    artifact_two = new_examples('uri_two', '2')

    m.register_execution(
        exec_properties={},
        pipeline_info=self._pipeline_info,
        component_info=self._component_info,
        contexts=contexts)
    m.publish_execution(
        component_info=self._component_info,
        output_artifacts={'key': [artifact_one, artifact_two]})

    span_one_only = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=1, end_span_number=1))
    resolver = spans_resolver.SpansResolver(range_config=span_one_only)
    producer_channel = types.Channel(
        type=artifact_one.type,
        producer_component_id=self._component_info.component_id,
        output_key='key')
    resolve_result = resolver.resolve(
        pipeline_info=self._pipeline_info,
        metadata_handler=m,
        source_channels={'input': producer_channel})

    self.assertTrue(resolve_result.has_complete_result)
    resolved_uris = [
        artifact.uri
        for artifact in resolve_result.per_key_resolve_result['input']
    ]
    self.assertEqual(resolved_uris, [artifact_one.uri])
    self.assertTrue(resolve_result.per_key_resolve_state['input'])
class FactoryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `factory.make_resolver_strategy_instance`."""

  def class_path_exists(self, class_path):
    """Returns whether `class_path` names an importable class.

    Args:
      class_path: Fully qualified `<module>.<ClassName>` path.

    Returns:
      True if the module can be imported and defines `ClassName`, else False.
    """
    module_name, class_name = class_path.rsplit('.', maxsplit=1)
    try:
      module = importlib.import_module(module_name)
    except ImportError:
      return False
    # The module can exist while the class itself has been removed/renamed.
    return hasattr(module, class_name)

  @parameterized.parameters(
      ('tfx.dsl.resolvers.oldest_artifacts_resolver'
       '.OldestArtifactsResolver', '{}'),
      ('tfx.dsl.resolvers.unprocessed_artifacts_resolver'
       '.UnprocessedArtifactsResolver', '{"execution_type_name": "Foo"}'),
      ('tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
       '.LatestArtifactStrategy', '{}'),
      ('tfx.dsl.input_resolution.strategies.latest_blessed_model_strategy'
       '.LatestBlessedModelStrategy', '{}'),
      ('tfx.dsl.input_resolution.strategies.span_range_strategy'
       '.SpanRangeStrategy',
       # SpanRangeStrategy's `range_config` is a RangeConfig, not a bare
       # StaticRange, so wrap the StaticRange accordingly.
       json_utils.dumps({
           'range_config':
               range_config_pb2.RangeConfig(
                   static_range=range_config_pb2.StaticRange(
                       start_span_number=1, end_span_number=10))
       })),
  )
  def test_make_resolver_strategy_instance(self, class_path, config_json):
    """Builds each strategy from its class path and JSON config."""
    if not self.class_path_exists(class_path):
      self.skipTest(f"Class path {class_path} doesn't exist.")
    resolver_step = pipeline_pb2.ResolverConfig.ResolverStep(
        class_path=class_path, config_json=config_json)
    result = factory.make_resolver_strategy_instance(resolver_step)
    self.assertIsInstance(result, resolver.ResolverStrategy)
    self.assertEndsWith(class_path, result.__class__.__name__)
def testRangeConfigWithDateSpec(self):
  """Tests RangeConfig.static_range with a {YYYY}{MM}{DD} date spec."""
  io_utils.write_string_file(
      os.path.join(self._input_base_path, '19700102', 'split1', 'data'),
      'testing11')

  # Use 1970-01-02 as both ends of the range so exactly one span is chosen.
  date_span = utils.date_to_span_number(1970, 1, 2)
  range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=date_span, end_span_number=date_span))
  date_splits = [
      example_gen_pb2.Input.Split(name='s1', pattern='{YYYY}{MM}{DD}/split1/*')
  ]

  _, span, version = utils.calculate_splits_fingerprint_span_and_version(
      self._input_base_path, date_splits, range_config=range_config)
  # 1970-01-02 corresponds to span 1.
  self.assertEqual(span, 1)
  self.assertIsNone(version)
  self.assertEqual(date_splits[0].pattern, '19700102/split1/*')
def testSpanAlignWithRangeConfig(self):
  """Tests that a static range picks span 1 even when span 2 also exists."""
  for span_dir, contents in (('span01', 'testing11'), ('span02', 'testing21')):
    data_file = os.path.join(self._input_base_path, span_dir, 'split1', 'data')
    io_utils.write_string_file(data_file, contents)

  # Test static range in RangeConfig.
  span_one = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=1))
  padded_splits = [
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
  ]

  _, span, version = utils.calculate_splits_fingerprint_span_and_version(
      self._input_base_path, padded_splits, span_one)
  self.assertEqual(span, 1)
  self.assertIsNone(version)
  self.assertEqual(padded_splits[0].pattern, 'span01/split1/*')
def testPenguinPipelineLocalWithRollingWindow(self):
  """Runs the local penguin pipeline over spans 1, 2 and 3.

  ExampleGen ingests one span per run via a static range, while the
  resolver's RollingRange(num_spans=2) is expected to hand Transform and
  Trainer the latest (at most) two spans.
  """
  examplegen_input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='test', pattern='day{SPAN}/*'),
  ])
  resolver_range_config = range_config_pb2.RangeConfig(
      rolling_range=range_config_pb2.RollingRange(num_spans=2))

  def run_pipeline(examplegen_range_config):
    """Runs the pipeline once with the given ExampleGen RangeConfig."""
    LocalDagRunner().run(
        penguin_pipeline_local._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root_span,
            module_file=self._module_file,
            accuracy_threshold=0.1,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            enable_tuning=False,
            examplegen_input_config=examplegen_input_config,
            examplegen_range_config=examplegen_range_config,
            resolver_range_config=resolver_range_config,
            beam_pipeline_args=[]))

  def single_span_range_config(span):
    """Returns a RangeConfig whose static range covers exactly `span`."""
    return range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=span, end_span_number=span))

  def span_values(artifacts):
    """Returns the set of span custom-property strings of `artifacts`."""
    return {
        artifact.custom_properties[utils.SPAN_PROPERTY_NAME].string_value
        for artifact in artifacts
    }

  # Trigger the pipeline for the first span.
  run_pipeline(single_span_range_config(1))
  self.assertTrue(fileio.exists(self._serving_model_dir))
  self.assertTrue(fileio.exists(self._metadata_path))
  self.assertPipelineExecution(False)
  transform_execution_type = 'tfx.components.transform.component.Transform'
  trainer_execution_type = 'tfx.components.trainer.component.Trainer'
  expected_execution_count = 10  # 8 components + 2 resolver
  # Built once and reused for every metadata check below.
  metadata_config = metadata.sqlite_metadata_connection_config(
      self._metadata_path)
  with metadata.Metadata(metadata_config) as m:
    artifact_count = len(m.store.get_artifacts())
    execution_count = len(m.store.get_executions())
    self.assertGreaterEqual(artifact_count, execution_count)
    self.assertEqual(expected_execution_count, execution_count)
    # Verify Transform's input examples artifacts.
    tft_input_examples_artifacts = self._get_input_examples_artifacts(
        m.store, transform_execution_type)
    self.assertLen(tft_input_examples_artifacts, 1)
    # SpansResolver (controlled by resolver_range_config) returns span 1.
    self.assertEqual({'1'}, span_values(tft_input_examples_artifacts))

  # Trigger the pipeline for the second span.
  run_pipeline(single_span_range_config(2))
  with metadata.Metadata(metadata_config) as m:
    execution_count = len(m.store.get_executions())
    self.assertEqual(expected_execution_count * 2, execution_count)
    # Verify Transform's input examples artifacts.
    tft_input_examples_artifacts = self._get_input_examples_artifacts(
        m.store, transform_execution_type)
    self.assertLen(tft_input_examples_artifacts, 2)
    # SpansResolver (controlled by resolver_range_config) returns span 1 & 2.
    self.assertSetEqual({'1', '2'}, span_values(tft_input_examples_artifacts))
    # Verify Trainer's input examples artifacts.
    self.assertLen(
        self._get_input_examples_artifacts(m.store, trainer_execution_type), 2)

  # Trigger the pipeline for the third span.
  run_pipeline(single_span_range_config(3))
  with metadata.Metadata(metadata_config) as m:
    execution_count = len(m.store.get_executions())
    self.assertEqual(expected_execution_count * 3, execution_count)
    # Verify Transform's input examples artifacts.
    tft_input_examples_artifacts = self._get_input_examples_artifacts(
        m.store, transform_execution_type)
    self.assertLen(tft_input_examples_artifacts, 2)
    # SpansResolver (controlled by resolver_range_config) returns span 2 & 3.
    self.assertSetEqual({'2', '3'}, span_values(tft_input_examples_artifacts))
    # Verify Trainer's input examples artifacts.
    self.assertLen(
        self._get_input_examples_artifacts(m.store, trainer_execution_type), 2)
def testResolveExecProperties(self):
  """Tests span/version resolution in FileBasedDriver.resolve_exec_properties.

  Covers these cases:
    * Rejection when the latest span differs across splits.
    * Rejection when the latest version differs across splits.
    * Selection of the latest aligned span/version when no RangeConfig is set.
    * Rejection of a static range whose start and end spans differ.
    * Selection of the span named by a single-span static range.
  """
  # Fingerprint of two splits that each hold one 9-byte file.
  expected_fingerprint = r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'

  def input_config_json():
    """Returns the unresolved two-split input config as JSON."""
    return proto_utils.proto_to_json(
        example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='span{SPAN:2}/version{VERSION:2}/split2/*'),
        ]))

  def static_range_json(start, end):
    """Returns a static-range RangeConfig as JSON."""
    return proto_utils.proto_to_json(
        range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=start, end_span_number=end)))

  def write_data(span_dir, version_dir, split_dir, contents):
    """Writes <input_base>/<span_dir>/<version_dir>/<split_dir>/data."""
    path = os.path.join(self._input_base_path, span_dir, version_dir,
                        split_dir, 'data')
    io_utils.write_string_file(path, contents)

  def read_input_config():
    """Parses the (possibly resolved) input config from exec properties."""
    config = example_gen_pb2.Input()
    proto_utils.json_to_proto(
        self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
        config)
    return config

  # Create input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  fileio.makedirs(self._input_base_path)

  # Create exec properties.
  self._exec_properties = {
      standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
      standard_component_specs.INPUT_CONFIG_KEY: input_config_json(),
      standard_component_specs.RANGE_CONFIG_KEY: None,
  }

  # Test align of span number: split1 reaches span02 but split2 does not.
  write_data('span01', 'version01', 'split1', 'testing11')
  write_data('span01', 'version01', 'split2', 'testing12')
  write_data('span02', 'version01', 'split1', 'testing21')
  # Check that error raised when span does not match.
  with self.assertRaisesRegex(
      ValueError, 'Latest span should be the same for each split'):
    self._file_based_driver.resolve_exec_properties(self._exec_properties,
                                                    None, None)

  write_data('span02', 'version01', 'split2', 'testing22')
  write_data('span02', 'version02', 'split1', 'testing21')
  # Check that error raised when span matches, but version does not match.
  with self.assertRaisesRegex(
      ValueError, 'Latest version should be the same for each split'):
    self._file_based_driver.resolve_exec_properties(self._exec_properties,
                                                    None, None)

  write_data('span02', 'version02', 'split2', 'testing22')
  # Test if latest span and version selected when span and version aligns
  # for each split.
  self._file_based_driver.resolve_exec_properties(self._exec_properties, None,
                                                  None)
  self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 2)
  self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 2)
  self.assertRegex(self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
                   expected_fingerprint)
  # Check if latest span is selected.
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span02/version02/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span02/version02/split2/*"
      }""", read_input_config())

  # Test driver behavior using RangeConfig with static range. The previous
  # resolution rewrote the stored input config, so restore the raw patterns.
  self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
      input_config_json())
  self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
      static_range_json(1, 2))
  with self.assertRaisesRegex(ValueError,
                              'For ExampleGen, start and end span numbers'):
    self._file_based_driver.resolve_exec_properties(self._exec_properties,
                                                    None, None)

  self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
      static_range_json(1, 1))
  self._file_based_driver.resolve_exec_properties(self._exec_properties, None,
                                                  None)
  self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1)
  self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 1)
  self.assertRegex(self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
                   expected_fingerprint)
  # Check if correct span inside static range is selected.
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/version01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/version01/split2/*"
      }""", read_input_config())