def testInputProcessor(self):
    """Verify that unsupported RangeConfigs are rejected at construction time.

    A single split with a `{SPAN}` pattern is combined with three range
    configs, each of which is expected to make `TestInputProcessor(...)`
    raise `ValueError`:
      * a StaticRange whose start and end span numbers differ,
      * a RollingRange with `num_spans` other than 1,
      * a RollingRange that sets `start_span_number`.
    """
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s', pattern='path/{SPAN}'),
    ])
    static_range_config = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=1, end_span_number=2))
    rolling_range_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(num_spans=2))
    rolling_range_config2 = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(
            num_spans=1, start_span_number=1))

    # `assertRaisesRegex` is used instead of the `assertRaisesRegexp` alias,
    # which was deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(
            ValueError,
            'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
    ):
        TestInputProcessor(input_config.splits, static_range_config)

    with self.assertRaisesRegex(
            ValueError,
            'ExampleGen only support single span for RangeConfig.RollingRange'
    ):
        TestInputProcessor(input_config.splits, rolling_range_config)

    with self.assertRaisesRegex(
            ValueError,
            'RangeConfig.rolling_range.start_span_number is not supported'
    ):
        TestInputProcessor(input_config.splits, rolling_range_config2)
def testInputProcessor(self):
    """Verify the validation errors raised when building an input processor.

    Covers, in order:
      * splits whose patterns do not share the same spec setup (one uses
        `{SPAN}`, the other does not) with no RangeConfig,
      * a RangeConfig supplied for a pattern that has no span/date spec,
      * a StaticRange whose start and end span numbers differ,
      * a RollingRange with `num_spans` other than 1,
      * a RollingRange that sets `start_span_number`.

    Each case is expected to raise `ValueError` from
    `TestInputProcessor(...)`.
    """
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='s1',
            pattern="select * from table where span={SPAN} and split='s1'"),
        example_gen_pb2.Input.Split(
            name='s2',
            pattern="select * from table where and split='s2'"),
    ])
    input_config2 = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s', pattern='select * from table'),
    ])
    input_config3 = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='s', pattern='select * from table where span={SPAN}'),
    ])

    static_range_config = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=1, end_span_number=2))
    rolling_range_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(num_spans=2))
    rolling_range_config2 = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(
            num_spans=1, start_span_number=1))

    # `assertRaisesRegex` is used instead of the `assertRaisesRegexp` alias,
    # which was deprecated in Python 3.2 and removed in Python 3.12.
    # The expected text mirrors the raised message verbatim (including its
    # "should the same" wording), so it must not be "corrected" here.
    with self.assertRaisesRegex(
            ValueError, 'Spec setup should the same for all splits'):
        TestInputProcessor(input_config.splits, None)

    with self.assertRaisesRegex(
            ValueError, 'Span or Date spec should be specified'):
        TestInputProcessor(input_config2.splits, static_range_config)

    with self.assertRaisesRegex(
            ValueError,
            'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
    ):
        TestInputProcessor(input_config3.splits, static_range_config)

    with self.assertRaisesRegex(
            ValueError,
            'ExampleGen only support single span for RangeConfig.RollingRange'
    ):
        TestInputProcessor(input_config3.splits, rolling_range_config)

    with self.assertRaisesRegex(
            ValueError,
            'RangeConfig.rolling_range.start_span_number is not supported'
    ):
        TestInputProcessor(input_config3.splits, rolling_range_config2)
def testResolveArtifacts(self):
    """Check which Examples artifacts SpansResolver returns per RangeConfig.

    Five artifacts are created via `self._createExamples(1)` through
    `self._createExamples(5)` (the argument is presumably the span number)
    and offered under the 'input' key. Three configurations are exercised:
    a static range 2..3, a rolling range of 3 spans, and a rolling range of
    3 spans with `start_span_number=4`.
    """
    with metadata.Metadata(connection_config=self._connection_config) as m:
        examples = [self._createExamples(number) for number in range(1, 6)]

        def resolve_inputs(range_config):
            # A fresh candidate list is handed over on every call so that
            # no resolution can observe changes made by a previous one.
            strategy = spans_resolver.SpansResolver(range_config=range_config)
            outcome = strategy.resolve_artifacts(m, {'input': list(examples)})
            self.assertIsNotNone(outcome)
            return outcome['input']

        # Test StaticRange.
        static_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=2, end_span_number=3))
        selected = resolve_inputs(static_config)
        self.assertEqual({a.uri for a in selected},
                         {examples[2].uri, examples[1].uri})

        # Test RollingRange.
        rolling_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=3))
        selected = resolve_inputs(rolling_config)
        self.assertEqual([a.uri for a in selected],
                         [examples[4].uri, examples[3].uri, examples[2].uri])

        # Test RollingRange with start_span_number.
        bounded_rolling_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(
                start_span_number=4, num_spans=3))
        selected = resolve_inputs(bounded_rolling_config)
        self.assertEqual([a.uri for a in selected],
                         [examples[4].uri, examples[3].uri])
def __init__(self,
             splits: Iterable[example_gen_pb2.Input.Split],
             range_config: Optional[range_config_pb2.RangeConfig] = None):
    """Initialize InputProcessor.

    Args:
      splits: An iterable collection of example_gen_pb2.Input.Split objects.
      range_config: An instance of range_config_pb2.RangeConfig, defines the
        rules for span resolving.

    Raises:
      ValueError: If the splits do not all share the same span/date/version
        spec setup; if a RangeConfig is given but the split patterns carry
        neither a span nor a date spec; if a static range has different
        start and end span numbers; if a rolling range has `num_spans`
        other than 1 or a positive `start_span_number`; or if the
        RangeConfig sets neither `static_range` nor `rolling_range`.
    """
    self._is_match_span = None
    self._is_match_date = None
    self._is_match_version = None

    for split in splits:
        has_span, has_date, has_version = utils.verify_split_pattern_specs(
            split)

        # The first split seen establishes the spec setup for all others.
        if self._is_match_span is None:
            self._is_match_span = has_span
            self._is_match_date = has_date
            self._is_match_version = has_version
            continue

        established = (self._is_match_span, self._is_match_date,
                       self._is_match_version)
        if established != (has_span, has_date, has_version):
            raise ValueError('Spec setup should the same for all splits: %s.' %
                             split.pattern)

    self._splits = splits

    if self._is_match_span or self._is_match_date:
        # A span/date spec without an explicit config falls back to a
        # rolling range of a single span.
        if not range_config:
            range_config = range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(num_spans=1))
    elif range_config:
        raise ValueError(
            'Span or Date spec should be specified in split pattern if RangeConfig is specified.'
        )

    if range_config:
        if range_config.HasField('static_range'):
            static_range = range_config.static_range
            if static_range.start_span_number != static_range.end_span_number:
                raise ValueError(
                    'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal.'
                )
        elif range_config.HasField('rolling_range'):
            rolling_range = range_config.rolling_range
            if rolling_range.num_spans != 1:
                raise ValueError(
                    'ExampleGen only support single span for RangeConfig.RollingRange.'
                )
            if rolling_range.start_span_number > 0:
                raise ValueError(
                    'RangeConfig.rolling_range.start_span_number is not supported.')
        else:
            raise ValueError('Only static_range and rolling_range are supported.')

    self._range_config = range_config
def testStrategy_IrMode(self):
    """Check which Examples artifacts SpanRangeStrategy returns per RangeConfig.

    Five artifacts are created via `self._createExamples(1)` through
    `self._createExamples(5)` (the argument is presumably the span number)
    and resolved against `self._store` under the 'input' key. Three
    configurations are exercised: a static range 2..3, a rolling range of
    3 spans, and a rolling range of 3 spans with `start_span_number=4`.
    """
    examples = [self._createExamples(number) for number in range(1, 6)]

    def resolve_inputs(range_config):
        # A fresh candidate list is handed over on every call so that no
        # resolution can observe changes made by a previous one.
        strategy = span_range_strategy.SpanRangeStrategy(
            range_config=range_config)
        outcome = strategy.resolve_artifacts(
            self._store, {'input': list(examples)})
        self.assertIsNotNone(outcome)
        return outcome['input']

    # Test StaticRange.
    static_config = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=2, end_span_number=3))
    selected = resolve_inputs(static_config)
    self.assertEqual({a.uri for a in selected},
                     {examples[2].uri, examples[1].uri})

    # Test RollingRange.
    rolling_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(num_spans=3))
    selected = resolve_inputs(rolling_config)
    self.assertEqual([a.uri for a in selected],
                     [examples[4].uri, examples[3].uri, examples[2].uri])

    # Test RollingRange with start_span_number.
    bounded_rolling_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(
            start_span_number=4, num_spans=3))
    selected = resolve_inputs(bounded_rolling_config)
    self.assertEqual([a.uri for a in selected],
                     [examples[4].uri, examples[3].uri])
def testQueryBasedInputProcessor(self):
    """Exercise QueryBasedInputProcessor span/version resolution.

    Three scenarios are checked:
      * span-templated queries with a rolling range: building the processor
        and resolving must raise `NotImplementedError`,
      * a plain query with no RangeConfig: span resolves to 0 and both the
        version and the fingerprint are None,
      * span-templated queries with a static range of span 2: span resolves
        to 2, version and fingerprint are None, and the first split's
        pattern is rendered with the date literal '19700103'.
    """
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s', pattern='select * from table'),
    ])
    input_config_span = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='s1',
            pattern='select * from table where date=@span_yyyymmdd_utc'),
        example_gen_pb2.Input.Split(
            name='s2',
            pattern='select * from table2 where date=@span_yyyymmdd_utc'),
    ])

    static_range_config = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=2, end_span_number=2))
    rolling_range_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(num_spans=1))

    # `assertRaisesRegex` is used instead of the `assertRaisesRegexp` alias,
    # which was deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(
            NotImplementedError,
            'For QueryBasedExampleGen, latest Span is not supported'):
        processor = input_processor.QueryBasedInputProcessor(
            input_config_span.splits, rolling_range_config)
        processor.resolve_span_and_version()

    # No span spec and no RangeConfig.
    processor = input_processor.QueryBasedInputProcessor(
        input_config.splits, None)
    span, version = processor.resolve_span_and_version()
    fp = processor.get_input_fingerprint(span, version)
    self.assertEqual(span, 0)
    self.assertIsNone(version)
    self.assertIsNone(fp)

    # Span spec pinned to span 2 by a static range.
    processor = input_processor.QueryBasedInputProcessor(
        input_config_span.splits, static_range_config)
    span, version = processor.resolve_span_and_version()
    fp = processor.get_input_fingerprint(span, version)
    self.assertEqual(span, 2)
    self.assertIsNone(version)
    self.assertIsNone(fp)

    pattern = processor.get_pattern_for_span_version(
        input_config_span.splits[0].pattern, span, version)
    self.assertEqual(pattern, "select * from table where date='19700103'")
def __init__(self,
             input_base_uri: Text,
             splits: Iterable[example_gen_pb2.Input.Split],
             range_config: Optional[range_config_pb2.RangeConfig] = None):
    """Initialize FileBasedInputProcessor.

    Split-pattern spec verification and RangeConfig defaulting/validation
    are delegated to the parent initializer, which receives `splits` and
    `range_config` unchanged. They are intentionally not repeated here:
    doing so would reset the `_is_match_*` flags and iterate `splits` a
    second time, which loses the flags when `splits` is a one-shot iterable.

    Args:
      input_base_uri: The base path from which files will be searched.
      splits: An iterable collection of example_gen_pb2.Input.Split objects.
      range_config: An instance of range_config_pb2.RangeConfig, defines the
        rules for span resolving.

    Raises:
      ValueError: Propagated from the parent initializer when the splits or
        the RangeConfig fail its validation.
    """
    super(FileBasedInputProcessor, self).__init__(
        splits=splits, range_config=range_config)

    self._input_base_uri = input_base_uri
    # No fingerprint is available at construction time; starts out as None.
    self._fingerprint = None
def testPenguinPipelineLocalWithRollingWindow(self):
    """Run the local penguin pipeline over three spans with a 2-span window.

    ExampleGen is pinned to one span per run through a StaticRange
    (spans 1, 2, then 3), while the resolver is configured with a
    RollingRange of two spans. After each run the test checks the total
    number of executions recorded in the metadata store and the span
    property of the Examples artifacts fed to Transform; from the second
    run on it also checks that Trainer received two Examples artifacts.
    """
    transform_execution_type = 'tfx.components.transform.component.Transform'
    trainer_execution_type = 'tfx.components.trainer.component.Trainer'
    expected_execution_count = 10  # 8 components + 2 resolver

    examplegen_input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='test', pattern='day{SPAN}/*'),
    ])
    resolver_range_config = range_config_pb2.RangeConfig(
        rolling_range=range_config_pb2.RollingRange(num_spans=2))

    def run_pipeline_for_span(span_number):
        # ExampleGen ingests exactly `span_number` (start == end).
        single_span_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=span_number, end_span_number=span_number))
        LocalDagRunner().run(
            penguin_pipeline_local._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root_span,
                module_file=self._module_file,
                accuracy_threshold=0.1,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                enable_tuning=False,
                examplegen_input_config=examplegen_input_config,
                examplegen_range_config=single_span_config,
                resolver_range_config=resolver_range_config,
                beam_pipeline_args=[]))

    def span_of(artifact):
        return artifact.custom_properties[
            utils.SPAN_PROPERTY_NAME].string_value

    def verify_two_span_window(connection_config, run_number, expected_spans):
        with metadata.Metadata(connection_config) as m:
            total_executions = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * run_number,
                             total_executions)
            # Verify Transform's input examples artifacts.
            transform_inputs = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(transform_inputs, 2)
            # SpansResolver (controlled by resolver_range_config) returns
            # the two spans in `expected_spans`.
            self.assertSetEqual(
                expected_spans, {span_of(a) for a in transform_inputs})
            # Verify Trainer's input examples artifacts.
            self.assertLen(
                self._get_input_examples_artifacts(
                    m.store, trainer_execution_type), 2)

    # Trigger the pipeline for the first span.
    run_pipeline_for_span(1)
    self.assertTrue(fileio.exists(self._serving_model_dir))
    self.assertTrue(fileio.exists(self._metadata_path))
    self.assertPipelineExecution(False)

    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
        total_artifacts = len(m.store.get_artifacts())
        total_executions = len(m.store.get_executions())
        self.assertGreaterEqual(total_artifacts, total_executions)
        self.assertEqual(expected_execution_count, total_executions)
        # Verify Transform's input examples artifacts.
        transform_inputs = self._get_input_examples_artifacts(
            m.store, transform_execution_type)
        self.assertLen(transform_inputs, 1)
        # SpansResolver (controlled by resolver_range_config) returns span 1.
        self.assertEqual('1', span_of(transform_inputs[0]))

    # Trigger the pipeline for the second span.
    run_pipeline_for_span(2)
    verify_two_span_window(metadata_config, 2, {'1', '2'})

    # Trigger the pipeline for the third span.
    run_pipeline_for_span(3)
    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    verify_two_span_window(metadata_config, 3, {'2', '3'})