コード例 #1
0
    def testInputProcessor(self):
        """Checks that unsupported RangeConfigs are rejected at construction.

        Each case builds a TestInputProcessor over a single `{SPAN}` split and
        expects a ValueError whose message matches the given regex.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s', pattern='path/{SPAN}'),
        ])
        # Static range whose start and end span numbers differ.
        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=2))
        # Rolling range asking for more than one span.
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))
        # Rolling range with an explicit start span number.
        rolling_range_config2 = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1,
                                                        start_span_number=1))

        invalid_cases = [
            (static_range_config,
             'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
             ),
            (rolling_range_config,
             'ExampleGen only support single span for RangeConfig.RollingRange'
             ),
            (rolling_range_config2,
             'RangeConfig.rolling_range.start_span_number is not supported'),
        ]
        for range_config, expected_message in invalid_cases:
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which no longer exists in unittest on Python 3.12+.
            with self.assertRaisesRegex(ValueError, expected_message):
                TestInputProcessor(input_config.splits, range_config)
コード例 #2
0
    def testInputProcessor(self):
        """Checks that invalid split/RangeConfig combinations raise ValueError.

        Covers splits whose specs disagree with each other, a RangeConfig given
        for a pattern without a span spec, and the RangeConfig shapes that the
        error messages below describe as unsupported.
        """
        # First split uses {SPAN}, second does not, so their specs differ.
        mixed_spec_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern="select * from table where span={SPAN} and split='s1'"
            ),
            example_gen_pb2.Input.Split(
                name='s2', pattern="select * from table where and split='s2'")
        ])
        # Pattern with no span spec at all.
        no_spec_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        # Single split with a {SPAN} spec.
        span_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s', pattern='select * from table where span={SPAN}'),
        ])
        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))
        rolling_range_config2 = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1,
                                                        start_span_number=1))

        invalid_cases = [
            (mixed_spec_input.splits, None,
             'Spec setup should the same for all splits'),
            (no_spec_input.splits, static_range_config,
             'Span or Date spec should be specified'),
            (span_input.splits, static_range_config,
             'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
             ),
            (span_input.splits, rolling_range_config,
             'ExampleGen only support single span for RangeConfig.RollingRange'
             ),
            (span_input.splits, rolling_range_config2,
             'RangeConfig.rolling_range.start_span_number is not supported'),
        ]
        for splits, range_config, expected_message in invalid_cases:
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which no longer exists in unittest on Python 3.12+.
            with self.assertRaisesRegex(ValueError, expected_message):
                TestInputProcessor(splits, range_config)
コード例 #3
0
    def testResolveArtifacts(self):
        """Resolves five Examples artifacts under three different RangeConfigs.

        The artifacts are created with `_createExamples(1)` through
        `_createExamples(5)`; each resolver run receives all five under the
        'input' key.
        """
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Keyed by the number passed to _createExamples, in creation order.
            examples = {n: self._createExamples(n) for n in range(1, 6)}

            def resolve(range_config):
                # Fresh resolver and fresh input dict for every scenario.
                resolver = spans_resolver.SpansResolver(
                    range_config=range_config)
                return resolver.resolve_artifacts(
                    m, {'input': list(examples.values())})

            # Test StaticRange: order is not checked, only membership.
            result = resolve(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=2, end_span_number=3)))
            self.assertIsNotNone(result)
            resolved_uris = {artifact.uri for artifact in result['input']}
            self.assertEqual(resolved_uris,
                             {examples[3].uri, examples[2].uri})

            # Test RollingRange: result order is part of the expectation.
            result = resolve(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(num_spans=3)))
            self.assertIsNotNone(result)
            resolved_uris = [artifact.uri for artifact in result['input']]
            self.assertEqual(
                resolved_uris,
                [examples[5].uri, examples[4].uri, examples[3].uri])

            # Test RollingRange with start_span_number.
            result = resolve(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(
                        start_span_number=4, num_spans=3)))
            self.assertIsNotNone(result)
            resolved_uris = [artifact.uri for artifact in result['input']]
            self.assertEqual(resolved_uris,
                             [examples[5].uri, examples[4].uri])
コード例 #4
0
ファイル: input_processor.py プロジェクト: sycdesign/tfx
  def __init__(self,
               splits: Iterable[example_gen_pb2.Input.Split],
               range_config: Optional[range_config_pb2.RangeConfig] = None):
    """Validates the split patterns and range config, then stores them.

    Args:
      splits: An iterable collection of example_gen_pb2.Input.Split objects.
      range_config: An instance of range_config_pb2.RangeConfig, defines the
        rules for span resolving. When omitted and the patterns carry a span
        or date spec, a single-span RollingRange is used.

    Raises:
      ValueError: If the splits do not all share the same span/date/version
        spec setup, if a range config is given without a span or date spec in
        the patterns, or if the range config is not a supported shape.
    """
    self._is_match_span = None
    self._is_match_date = None
    self._is_match_version = None
    for split in splits:
      span_spec, date_spec, version_spec = utils.verify_split_pattern_specs(
          split)
      if self._is_match_span is None:
        # Nothing recorded yet: this split defines the reference spec setup.
        self._is_match_span = span_spec
        self._is_match_date = date_spec
        self._is_match_version = version_spec
        continue
      recorded = (self._is_match_span, self._is_match_date,
                  self._is_match_version)
      if recorded != (span_spec, date_spec, version_spec):
        raise ValueError('Spec setup should the same for all splits: %s.' %
                         split.pattern)

    self._splits = splits

    if self._is_match_span or self._is_match_date:
      if not range_config:
        # No explicit config: default to a rolling range of one span.
        range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1))
    elif range_config:
      raise ValueError(
          'Span or Date spec should be specified in split pattern if RangeConfig is specified.'
      )

    if range_config:
      if range_config.HasField('static_range'):
        static_range = range_config.static_range
        if static_range.start_span_number != static_range.end_span_number:
          raise ValueError(
              'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal.'
          )
      elif range_config.HasField('rolling_range'):
        rolling_range = range_config.rolling_range
        if rolling_range.num_spans != 1:
          raise ValueError(
              'ExampleGen only support single span for RangeConfig.RollingRange.'
          )
        if rolling_range.start_span_number > 0:
          raise ValueError(
              'RangeConfig.rolling_range.start_span_number is not supported.')
      else:
        raise ValueError('Only static_range and rolling_range are supported.')

    self._range_config = range_config
コード例 #5
0
    def testStrategy_IrMode(self):
        """Runs SpanRangeStrategy against `self._store` for three RangeConfigs.

        Five Examples artifacts are created with `_createExamples(1)` through
        `_createExamples(5)` and all five are offered under the 'input' key in
        every scenario.
        """
        # Keyed by the number passed to _createExamples, in creation order.
        examples = {n: self._createExamples(n) for n in range(1, 6)}

        def resolve(range_config):
            # Fresh strategy and fresh input dict for every scenario.
            strategy = span_range_strategy.SpanRangeStrategy(
                range_config=range_config)
            return strategy.resolve_artifacts(
                self._store, {'input': list(examples.values())})

        # Test StaticRange: order is not checked, only membership.
        result = resolve(
            range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                          end_span_number=3)))
        self.assertIsNotNone(result)
        resolved_uris = {artifact.uri for artifact in result['input']}
        self.assertEqual(resolved_uris, {examples[3].uri, examples[2].uri})

        # Test RollingRange: result order is part of the expectation.
        result = resolve(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(num_spans=3)))
        self.assertIsNotNone(result)
        resolved_uris = [artifact.uri for artifact in result['input']]
        self.assertEqual(
            resolved_uris,
            [examples[5].uri, examples[4].uri, examples[3].uri])

        # Test RollingRange with start_span_number.
        result = resolve(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(
                    start_span_number=4, num_spans=3)))
        self.assertIsNotNone(result)
        resolved_uris = [artifact.uri for artifact in result['input']]
        self.assertEqual(resolved_uris, [examples[5].uri, examples[4].uri])
コード例 #6
0
    def testQueryBasedInputProcessor(self):
        """Exercises QueryBasedInputProcessor with and without a span spec.

        Covers three scenarios: a rolling range on a span pattern (expected to
        raise NotImplementedError), a plain query with no range config, and a
        span pattern with a static range pinned to span 2.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        input_config_span = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='select * from table where date=@span_yyyymmdd_utc'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='select * from table2 where date=@span_yyyymmdd_utc')
        ])

        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists in unittest on Python 3.12+.
        with self.assertRaisesRegex(
                NotImplementedError,
                'For QueryBasedExampleGen, latest Span is not supported'):
            processor = input_processor.QueryBasedInputProcessor(
                input_config_span.splits, rolling_range_config)
            processor.resolve_span_and_version()

        # No span spec and no range config: span resolves to 0.
        processor = input_processor.QueryBasedInputProcessor(
            input_config.splits, None)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 0)
        self.assertIsNone(version)
        self.assertIsNone(fp)

        # Static range on a span pattern: the configured span is returned.
        processor = input_processor.QueryBasedInputProcessor(
            input_config_span.splits, static_range_config)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 2)
        self.assertIsNone(version)
        self.assertIsNone(fp)
        # The span placeholder in the query is replaced by a quoted date.
        pattern = processor.get_pattern_for_span_version(
            input_config_span.splits[0].pattern, span, version)
        self.assertEqual(pattern, "select * from table where date='19700103'")
コード例 #7
0
    def __init__(self,
                 input_base_uri: Text,
                 splits: Iterable[example_gen_pb2.Input.Split],
                 range_config: Optional[range_config_pb2.RangeConfig] = None):
        """Initialize FileBasedInputProcessor.

        Args:
          input_base_uri: The base path from which files will be searched.
          splits: An iterable collection of example_gen_pb2.Input.Split
            objects.
          range_config: An instance of range_config_pb2.RangeConfig, defines
            the rules for span resolving.

        Raises:
          ValueError: If the splits do not all share the same
            span/date/version spec setup, or if a range config is given while
            the patterns carry neither a span nor a date spec.
        """
        super(FileBasedInputProcessor,
              self).__init__(splits=splits, range_config=range_config)

        self._is_match_span = None
        self._is_match_date = None
        self._is_match_version = None
        for split in splits:
            span_spec, date_spec, version_spec = (
                utils.verify_split_pattern_specs(split))
            if self._is_match_span is None:
                # Nothing recorded yet: this split defines the reference.
                self._is_match_span = span_spec
                self._is_match_date = date_spec
                self._is_match_version = version_spec
                continue
            recorded = (self._is_match_span, self._is_match_date,
                        self._is_match_version)
            if recorded != (span_spec, date_spec, version_spec):
                raise ValueError(
                    'Spec setup should the same for all splits: %s.' %
                    split.pattern)

        if self._is_match_span or self._is_match_date:
            if not range_config:
                # NOTE(review): this default is bound only to the local name;
                # nothing later in this method stores it on the instance, and
                # the base-class __init__ above already received the original
                # value. Confirm whether the default is meant to take effect.
                range_config = range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(num_spans=1))
        elif range_config:
            raise ValueError(
                'Span or Date spec should be specified in split pattern if RangeConfig is specified.'
            )

        self._input_base_uri = input_base_uri
        self._fingerprint = None
コード例 #8
0
    def testPenguinPipelineLocalWithRollingWindow(self):
        """Runs the penguin pipeline for spans 1, 2 and 3 in turn.

        ExampleGen is pinned to one span per run through a StaticRange, while
        the resolver uses a RollingRange of two spans. After each run the test
        checks the execution count and which spans fed Transform (and, from
        the second run on, how many Examples artifacts fed Trainer).
        """
        transform_execution_type = 'tfx.components.transform.component.Transform'
        trainer_execution_type = 'tfx.components.trainer.component.Trainer'
        expected_execution_count = 10  # 8 components + 2 resolver

        examplegen_input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='test', pattern='day{SPAN}/*'),
        ])
        resolver_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))

        def run_pipeline(examplegen_range_config):
            created_pipeline = penguin_pipeline_local._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root_span,
                module_file=self._module_file,
                accuracy_threshold=0.1,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                enable_tuning=False,
                examplegen_input_config=examplegen_input_config,
                examplegen_range_config=examplegen_range_config,
                resolver_range_config=resolver_range_config,
                beam_pipeline_args=[])
            LocalDagRunner().run(created_pipeline)

        def single_span_config(span_number):
            # StaticRange whose start and end are the same span.
            return range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=span_number,
                    end_span_number=span_number))

        def span_values(artifacts):
            # Distinct span custom-property values across the artifacts.
            return {
                artifact.custom_properties[
                    utils.SPAN_PROPERTY_NAME].string_value
                for artifact in artifacts
            }

        # Trigger the pipeline for the first span.
        run_pipeline(single_span_config(1))

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        self.assertPipelineExecution(False)
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 1)
            # SpansResolver (controlled by resolver_range_config) returns span 1.
            first_artifact = tft_input_examples_artifacts[0]
            self.assertEqual(
                '1', first_artifact.custom_properties[
                    utils.SPAN_PROPERTY_NAME].string_value)

        # Trigger the pipeline for the second span.
        run_pipeline(single_span_config(2))

        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 2, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 1 & 2.
            self.assertSetEqual({'1', '2'},
                                span_values(tft_input_examples_artifacts))
            # Verify Trainer's input examples artifacts.
            trainer_input_examples_artifacts = (
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type))
            self.assertLen(trainer_input_examples_artifacts, 2)

        # Trigger the pipeline for the third span.
        run_pipeline(single_span_config(3))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 3, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 2 & 3.
            self.assertSetEqual({'2', '3'},
                                span_values(tft_input_examples_artifacts))
            # Verify Trainer's input examples artifacts.
            trainer_input_examples_artifacts = (
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type))
            self.assertLen(trainer_input_examples_artifacts, 2)