Beispiel #1
0
 def get_input_processor(
     self,
     splits: Iterable[example_gen_pb2.Input.Split],
     range_config: Optional[range_config_pb2.RangeConfig] = None,
     input_base_uri: Optional[Text] = None
 ) -> input_processor.InputProcessor:
     """Builds the input processor used for query-based input.

     Args:
       splits: Input splits; for this processor each split's `pattern` is a
         query string rather than a file pattern.
       range_config: Optional span range configuration, forwarded unchanged
         to the processor.
       input_base_uri: Accepted to keep the method signature uniform; it is
         not used by this implementation.

     Returns:
       An `input_processor.QueryBasedInputProcessor` constructed from
       `splits` and `range_config`.
     """
     # Query-based input has no base directory, so the URI is ignored.
     del input_base_uri
     processor_cls = input_processor.QueryBasedInputProcessor
     return processor_cls(splits, range_config)
Beispiel #2
0
    def testQueryBasedInputProcessor(self):
        """Checks span/version resolution and query substitution.

        Covers three configurations of QueryBasedInputProcessor:
          * a RollingRange config: constructing the processor or resolving
            the span must raise NotImplementedError;
          * no range config: resolves to span 0, no version, no fingerprint;
          * a StaticRange pinned to span 2: resolves to span 2, no version,
            no fingerprint, and the `@span_yyyymmdd_utc` placeholder in a
            query is replaced by the quoted date '19700103'.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        # Queries containing the span placeholder that the processor fills in.
        input_config_span = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='select * from table where date=@span_yyyymmdd_utc'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='select * from table2 where date=@span_yyyymmdd_utc')
        ])

        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists as of Python 3.12.
        with self.assertRaisesRegex(
                NotImplementedError,
                'For QueryBasedExampleGen, latest Span is not supported'):
            processor = input_processor.QueryBasedInputProcessor(
                input_config_span.splits, rolling_range_config)
            processor.resolve_span_and_version()

        processor = input_processor.QueryBasedInputProcessor(
            input_config.splits, None)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 0)
        self.assertIsNone(version)
        self.assertIsNone(fp)

        processor = input_processor.QueryBasedInputProcessor(
            input_config_span.splits, static_range_config)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 2)
        self.assertIsNone(version)
        self.assertIsNone(fp)
        # Span 2 is expected to render as 19700103 (presumably span counts
        # days since 1970-01-01 UTC).
        pattern = processor.get_pattern_for_span_version(
            input_config_span.splits[0].pattern, span, version)
        self.assertEqual(pattern, "select * from table where date='19700103'")