Beispiel #1
0
    def testRangeConfigSpanWidthPresence(self):
        """Checks RangeConfig.static_range handling of the span width modifier.

        Without a width modifier (`{SPAN}`), static span 1 does not match the
        zero-padded directory `span01`, so a ValueError is expected. With
        `{SPAN:2}`, the span is substituted with zero padding and the split
        pattern is rewritten in place.
        """
        span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                                    'data')
        io_utils.write_string_file(span1_split1, 'testing11')

        range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=1))
        splits1 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN}/split1/*')
        ]

        # RangeConfig cannot find a zero-padded span without width modifier.
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists on unittest.TestCase as of Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'Cannot find matching for split'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits1, range_config=range_config)

        splits2 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN:2}/split1/*')
        ]

        # With width modifier in span spec, RangeConfig.static_range makes
        # correct zero-padded substitution.
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits2, range_config=range_config)
        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(splits2[0].pattern, 'span01/split1/*')
Beispiel #2
0
    def testFileBasedInputProcessor(self):
        """Checks FileBasedInputProcessor constructor validation errors.

        Mixing a split with a span spec and a split without one raises
        ValueError. Supplying a static range when no split has a span or date
        spec also raises ValueError.
        """
        # TODO(b/181275944): migrate test after refactoring FileBasedInputProcessor.

        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s1', pattern='path/{SPAN}'),
            example_gen_pb2.Input.Split(name='s2', pattern='path2')
        ])
        input_config2 = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s', pattern='path'),
        ])

        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists on unittest.TestCase as of Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Spec setup should the same for all splits'):
            input_processor.FileBasedInputProcessor('input_base_uri',
                                                    input_config.splits, None)

        with self.assertRaisesRegex(ValueError,
                                    'Span or Date spec should be specified'):
            input_processor.FileBasedInputProcessor('input_base_uri',
                                                    input_config2.splits,
                                                    static_range_config)
Beispiel #3
0
    def testInputProcessor(self):
        """Checks which RangeConfig variants the InputProcessor rejects.

        A static range spanning two spans, a rolling range of two spans, and a
        rolling range with start_span_number each raise ValueError.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s', pattern='path/{SPAN}'),
        ])
        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))
        rolling_range_config2 = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1,
                                                        start_span_number=1))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists on unittest.TestCase as of Python 3.12.
        with self.assertRaisesRegex(
                ValueError,
                'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
        ):
            TestInputProcessor(input_config.splits, static_range_config)

        with self.assertRaisesRegex(
                ValueError,
                'ExampleGen only support single span for RangeConfig.RollingRange'
        ):
            TestInputProcessor(input_config.splits, rolling_range_config)

        with self.assertRaisesRegex(
                ValueError,
                'RangeConfig.rolling_range.start_span_number is not supported'
        ):
            TestInputProcessor(input_config.splits, rolling_range_config2)
Beispiel #4
0
 def testConstructSubclassQueryBasedWithRangeConfig(self):
     """A query-based ExampleGen subclass keeps its RangeConfig in exec props.

     Per the test's intent, @span_yyyymmdd_utc in the query is replaced with
     '19700103' at query time, and span `2` is recorded on the output
     Examples artifact. This test checks the component wiring and the
     serialized RangeConfig.
     """
     static_span = range_config_pb2.StaticRange(start_span_number=2,
                                                end_span_number=2)
     range_config = range_config_pb2.RangeConfig(static_range=static_span)
     single_split = example_gen_pb2.Input.Split(
         name='single',
         pattern='select * from table where date=@span_yyyymmdd_utc')
     example_gen = TestQueryBasedExampleGenComponent(
         input_config=example_gen_pb2.Input(splits=[single_split]),
         range_config=range_config)

     # A query-based component has no input channels.
     self.assertEqual({}, example_gen.inputs)
     self.assertEqual(driver.QueryBasedDriver, example_gen.driver_class)
     examples_channel = example_gen.outputs[
         standard_component_specs.EXAMPLES_KEY]
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     # Round-trip the stored JSON config back into a proto and compare.
     serialized_config = example_gen.exec_properties[
         standard_component_specs.RANGE_CONFIG_KEY]
     stored_range_config = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(serialized_config, stored_range_config)
     self.assertEqual(range_config, stored_range_config)
Beispiel #5
0
    def testQueryBasedDriver(self):
        """QueryBasedDriver resolves a static span into every split query.

        Runs the driver with a static range of span 2 and checks four things:
        span/version/fingerprint are written to exec properties, `{SPAN}` is
        substituted in each split pattern, exactly one output Examples
        artifact is produced, and that artifact carries the span as a custom
        property.
        """
        # Create exec properties.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern="select * from table where span={SPAN} and split='s1'"),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern="select * from table where span={SPAN} and split='s2'"),
        ])
        range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))
        exec_properties = {
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(input_config),
            standard_component_specs.RANGE_CONFIG_KEY:
                proto_utils.proto_to_json(range_config),
        }

        # Prepare the output dict with a single Examples artifact.
        example = standard_artifacts.Examples()
        example.uri = 'my_uri'
        output_dict = {standard_component_specs.EXAMPLES_KEY: [example]}

        execution_info = portable_data_types.ExecutionInfo(
            output_dict=output_dict, exec_properties=exec_properties)
        query_based_driver = driver.QueryBasedDriver(self._mock_metadata)
        result = query_based_driver.run(execution_info)

        # The run writes span/version/fingerprint back into exec_properties.
        self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        self.assertIsNone(exec_properties[utils.VERSION_PROPERTY_NAME])
        self.assertIsNone(exec_properties[utils.FINGERPRINT_PROPERTY_NAME])

        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            updated_input_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "select * from table where span=2 and split='s1'"
        }
        splits {
          name: "s2"
          pattern: "select * from table where span=2 and split='s2'"
        }""", updated_input_config)

        output_artifacts = result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(output_artifacts, 1)
        output_example = output_artifacts[0]
        self.assertEqual(output_example.uri, example.uri)
        span_property = output_example.custom_properties[
            utils.SPAN_PROPERTY_NAME]
        self.assertEqual(span_property.string_value, '2')
Beispiel #6
0
 def testConstructWithStaticRangeConfig(self):
     """FileBasedExampleGen serializes a static RangeConfig into exec props."""
     range_config = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                   end_span_number=1))
     custom_executor = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
     example_gen = component.FileBasedExampleGen(
         input_base='path',
         range_config=range_config,
         custom_executor_spec=custom_executor)

     # The config is stored as JSON; parse it back into a fresh proto.
     stored_range_config = json_format.Parse(
         example_gen.exec_properties['range_config'],
         range_config_pb2.RangeConfig())
     self.assertEqual(range_config, stored_range_config)
Beispiel #7
0
    def testInputProcessor(self):
        """Checks InputProcessor validation of split specs and RangeConfig.

        Each case below raises ValueError:
        - splits with inconsistent span specs;
        - a static range with no span/date spec in the query;
        - a multi-span static range;
        - a multi-span rolling range;
        - a rolling range with start_span_number.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern="select * from table where span={SPAN} and split='s1'"
            ),
            example_gen_pb2.Input.Split(
                name='s2', pattern="select * from table where and split='s2'")
        ])
        input_config2 = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        input_config3 = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s', pattern='select * from table where span={SPAN}'),
        ])
        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))
        rolling_range_config2 = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1,
                                                        start_span_number=1))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists on unittest.TestCase as of Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Spec setup should the same for all splits'):
            TestInputProcessor(input_config.splits, None)

        with self.assertRaisesRegex(ValueError,
                                    'Span or Date spec should be specified'):
            TestInputProcessor(input_config2.splits, static_range_config)

        with self.assertRaisesRegex(
                ValueError,
                'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal'
        ):
            TestInputProcessor(input_config3.splits, static_range_config)

        with self.assertRaisesRegex(
                ValueError,
                'ExampleGen only support single span for RangeConfig.RollingRange'
        ):
            TestInputProcessor(input_config3.splits, rolling_range_config)

        with self.assertRaisesRegex(
                ValueError,
                'RangeConfig.rolling_range.start_span_number is not supported'
        ):
            TestInputProcessor(input_config3.splits, rolling_range_config2)
Beispiel #8
0
 def testConstructWithStaticRangeConfig(self):
   """FileBasedExampleGen stores a static RangeConfig in exec properties."""
   static_range = range_config_pb2.StaticRange(
       start_span_number=1, end_span_number=1)
   range_config = range_config_pb2.RangeConfig(static_range=static_range)
   example_gen = component.FileBasedExampleGen(
       input_base='path',
       range_config=range_config,
       custom_executor_spec=executor_spec.BeamExecutorSpec(
           TestExampleGenExecutor))

   # Round-trip the JSON-serialized config back into a proto for comparison.
   serialized_config = example_gen.exec_properties[
       standard_component_specs.RANGE_CONFIG_KEY]
   stored_range_config = range_config_pb2.RangeConfig()
   proto_utils.json_to_proto(serialized_config, stored_range_config)
   self.assertEqual(range_config, stored_range_config)
Beispiel #9
0
  def testRangeConfigWithNonexistentSpan(self):
    """A static span with no matching files on disk raises ValueError."""
    # Only span 1 is written; the range config below asks for span 2.
    existing_data = os.path.join(self._input_base_path, 'span01', 'split1',
                                 'data')
    io_utils.write_string_file(existing_data, 'testing11')

    missing_span = range_config_pb2.StaticRange(
        start_span_number=2, end_span_number=2)
    range_config = range_config_pb2.RangeConfig(static_range=missing_span)
    splits = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*'),
    ]

    with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, splits, range_config=range_config)
Beispiel #10
0
    def testResolveArtifacts(self):
        """SpansResolver selects artifacts by static and rolling span ranges."""
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # One Examples artifact per span number 1..5, created in order.
            examples = [self._createExamples(span) for span in range(1, 6)]
            _, artifact2, artifact3, artifact4, artifact5 = examples

            def resolved_uris(range_config):
                # Resolve a fresh copy of the candidate list under the config.
                resolver = spans_resolver.SpansResolver(
                    range_config=range_config)
                result = resolver.resolve_artifacts(
                    m, {'input': list(examples)})
                self.assertIsNotNone(result)
                return [a.uri for a in result['input']]

            # Test StaticRange: spans 2..3, compared without ordering.
            static_uris = resolved_uris(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=2, end_span_number=3)))
            self.assertEqual(set(static_uris), {artifact3.uri, artifact2.uri})

            # Test RollingRange: expected newest span first.
            rolling_uris = resolved_uris(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(num_spans=3)))
            self.assertEqual(rolling_uris,
                             [artifact5.uri, artifact4.uri, artifact3.uri])

            # Test RollingRange with start_span_number: only spans 4 and 5.
            bounded_uris = resolved_uris(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(
                        start_span_number=4, num_spans=3)))
            self.assertEqual(bounded_uris, [artifact5.uri, artifact4.uri])
Beispiel #11
0
    def testQueryBasedInputProcessor(self):
        """Checks QueryBasedInputProcessor span resolution and rewriting.

        - A rolling range (latest span) raises NotImplementedError.
        - A query with no span spec resolves to span 0, with no version and
          no fingerprint.
        - A static range of span 2 resolves to span 2 and rewrites
          @span_yyyymmdd_utc to the quoted date '19700103'.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        input_config_span = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='select * from table where date=@span_yyyymmdd_utc'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='select * from table2 where date=@span_yyyymmdd_utc')
        ])

        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1))

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which no longer exists on unittest.TestCase as of Python 3.12.
        with self.assertRaisesRegex(
                NotImplementedError,
                'For QueryBasedExampleGen, latest Span is not supported'):
            processor = input_processor.QueryBasedInputProcessor(
                input_config_span.splits, rolling_range_config)
            processor.resolve_span_and_version()

        processor = input_processor.QueryBasedInputProcessor(
            input_config.splits, None)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 0)
        self.assertIsNone(version)
        self.assertIsNone(fp)

        processor = input_processor.QueryBasedInputProcessor(
            input_config_span.splits, static_range_config)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 2)
        self.assertIsNone(version)
        self.assertIsNone(fp)
        pattern = processor.get_pattern_for_span_version(
            input_config_span.splits[0].pattern, span, version)
        self.assertEqual(pattern, "select * from table where date='19700103'")
Beispiel #12
0
 def testConstructWithRangeConfig(self):
     """BigQueryExampleGen stores its static RangeConfig in exec properties.

     Per the test's intent, @span_yyyymmdd_utc in the query is replaced with
     '19700103' at query time, and span `2` is recorded on the output
     Examples artifact. This test checks the output type and the serialized
     RangeConfig.
     """
     range_config = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                   end_span_number=2))
     big_query_example_gen = component.BigQueryExampleGen(
         query='select * from table where date=@span_yyyymmdd_utc',
         range_config=range_config)

     examples_output = big_query_example_gen.outputs[
         standard_component_specs.EXAMPLES_KEY]
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_output.type_name)

     # Parse the JSON-serialized config back into a proto and compare.
     serialized_config = big_query_example_gen.exec_properties[
         standard_component_specs.RANGE_CONFIG_KEY]
     stored_range_config = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(serialized_config, stored_range_config)
     self.assertEqual(range_config, stored_range_config)
Beispiel #13
0
    def testStrategy_IrMode(self):
        """SpanRangeStrategy selects artifacts by static and rolling ranges."""
        # One Examples artifact per span number 1..5, created in order.
        examples = [self._createExamples(span) for span in range(1, 6)]
        _, artifact2, artifact3, artifact4, artifact5 = examples

        def resolved_uris(range_config):
            # Resolve a fresh copy of the candidate list against the store.
            strategy = span_range_strategy.SpanRangeStrategy(
                range_config=range_config)
            result = strategy.resolve_artifacts(self._store,
                                                {'input': list(examples)})
            self.assertIsNotNone(result)
            return [a.uri for a in result['input']]

        # Test StaticRange: spans 2..3, compared without ordering.
        static_uris = resolved_uris(
            range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=2, end_span_number=3)))
        self.assertEqual(set(static_uris), {artifact3.uri, artifact2.uri})

        # Test RollingRange: expected newest span first.
        rolling_uris = resolved_uris(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(num_spans=3)))
        self.assertEqual(rolling_uris,
                         [artifact5.uri, artifact4.uri, artifact3.uri])

        # Test RollingRange with start_span_number: only spans 4 and 5.
        bounded_uris = resolved_uris(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(
                    start_span_number=4, num_spans=3)))
        self.assertEqual(bounded_uris, [artifact5.uri, artifact4.uri])
Beispiel #14
0
    def testResolve(self):
        """SpansResolver.resolve keeps only artifacts inside the static span."""
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = m.register_pipeline_contexts_if_not_exists(
                self._pipeline_info)

            # Span-1 artifact, published up front.
            artifact_one = standard_artifacts.Examples()
            artifact_one.uri = 'uri_one'
            artifact_one.set_string_custom_property(
                utils.SPAN_PROPERTY_NAME, '1')
            m.publish_artifacts([artifact_one])

            # Span-2 artifact, published only as an execution output below.
            artifact_two = standard_artifacts.Examples()
            artifact_two.uri = 'uri_two'
            artifact_two.set_string_custom_property(
                utils.SPAN_PROPERTY_NAME, '2')

            m.register_execution(exec_properties={},
                                 pipeline_info=self._pipeline_info,
                                 component_info=self._component_info,
                                 contexts=contexts)
            m.publish_execution(
                component_info=self._component_info,
                output_artifacts={'key': [artifact_one, artifact_two]})

            span_one_config = range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=1, end_span_number=1))
            resolver = spans_resolver.SpansResolver(
                range_config=span_one_config)
            input_channel = types.Channel(
                type=artifact_one.type,
                producer_component_id=self._component_info.component_id,
                output_key='key')
            resolve_result = resolver.resolve(
                pipeline_info=self._pipeline_info,
                metadata_handler=m,
                source_channels={'input': input_channel})

            self.assertTrue(resolve_result.has_complete_result)
            resolved_uris = [
                artifact.uri
                for artifact in resolve_result.per_key_resolve_result['input']
            ]
            self.assertEqual(resolved_uris, [artifact_one.uri])
            self.assertTrue(resolve_result.per_key_resolve_state['input'])
Beispiel #15
0
class FactoryTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `factory.make_resolver_strategy_instance`."""

  def class_path_exists(self, class_path):
    """Returns whether `class_path` names an attribute of an importable module.

    Args:
      class_path: Dotted path of the form `package.module.ClassName`.

    Returns:
      True if the module imports and defines `ClassName`, otherwise False.
      The attribute is checked as well as the module, so a missing class makes
      the test skip instead of failing inside the factory.
    """
    module_name, class_name = class_path.rsplit('.', maxsplit=1)
    try:
      module = importlib.import_module(module_name)
    except ImportError:
      return False
    return hasattr(module, class_name)

  @parameterized.parameters(
      ('tfx.dsl.resolvers.oldest_artifacts_resolver'
       '.OldestArtifactsResolver',
       '{}'),
      ('tfx.dsl.resolvers.unprocessed_artifacts_resolver'
       '.UnprocessedArtifactsResolver',
       '{"execution_type_name": "Foo"}'),
      ('tfx.dsl.input_resolution.strategies.latest_artifact_strategy'
       '.LatestArtifactStrategy',
       '{}'),
      ('tfx.dsl.input_resolution.strategies.latest_blessed_model_strategy'
       '.LatestBlessedModelStrategy',
       '{}'),
      ('tfx.dsl.input_resolution.strategies.span_range_strategy'
       '.SpanRangeStrategy',
       # SpanRangeStrategy's `range_config` is a RangeConfig, so the
       # StaticRange is wrapped rather than passed as the bare message.
       json_utils.dumps({
           'range_config': range_config_pb2.RangeConfig(
               static_range=range_config_pb2.StaticRange(
                   start_span_number=1, end_span_number=10))
       })),
      )
  def test_make_resolver_strategy_instance(self, class_path, config_json):
    """The factory builds a ResolverStrategy of the class named by the path."""
    if not self.class_path_exists(class_path):
      self.skipTest(f"Class path {class_path} doesn't exist.")
    resolver_step = pipeline_pb2.ResolverConfig.ResolverStep(
        class_path=class_path,
        config_json=config_json)

    result = factory.make_resolver_strategy_instance(resolver_step)

    self.assertIsInstance(result, resolver.ResolverStrategy)
    self.assertEndsWith(class_path, result.__class__.__name__)
Beispiel #16
0
    def testRangeConfigWithDateSpec(self):
        """A static range built from a date resolves a {YYYY}{MM}{DD} pattern."""
        date_dir_data = os.path.join(self._input_base_path, '19700102',
                                     'split1', 'data')
        io_utils.write_string_file(date_dir_data, 'testing11')

        # Use the span number of 1970-01-02 as both ends of the static range;
        # the test expects it to resolve to span 1.
        date_span = utils.date_to_span_number(1970, 1, 2)
        range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=date_span, end_span_number=date_span))

        date_split = example_gen_pb2.Input.Split(
            name='s1', pattern='{YYYY}{MM}{DD}/split1/*')
        splits = [date_split]
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits, range_config=range_config)

        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(splits[0].pattern, '19700102/split1/*')
Beispiel #17
0
  def testSpanAlignWithRangeConfig(self):
    """A static RangeConfig pins span 1 even though span 2 also exists."""
    for span_dir, contents in (('span01', 'testing11'),
                               ('span02', 'testing21')):
      data_path = os.path.join(self._input_base_path, span_dir, 'split1',
                               'data')
      io_utils.write_string_file(data_path, contents)

    # Test static range in RangeConfig.
    pinned_span = range_config_pb2.StaticRange(
        start_span_number=1, end_span_number=1)
    range_config = range_config_pb2.RangeConfig(static_range=pinned_span)
    splits = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*'),
    ]

    _, span, version = utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, splits, range_config)
    self.assertEqual(span, 1)
    self.assertIsNone(version)
    self.assertEqual(splits[0].pattern, 'span01/split1/*')
    def testPenguinPipelineLocalWithRollingWindow(self):
        """Runs the penguin pipeline over spans 1..3 with a 2-span window.

        Each run, ExampleGen ingests a single span via a static range. The
        resolver uses a rolling range of two spans, so Transform sees span 1
        after the first run, spans 1 and 2 after the second, and spans 2 and 3
        after the third. Trainer sees two input Examples artifacts after the
        second and third runs.
        """
        examplegen_input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='test', pattern='day{SPAN}/*'),
        ])
        resolver_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))

        def run_pipeline_for_span(span_number):
            # ExampleGen ingests exactly `span_number` on this run.
            examplegen_range_config = range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=span_number,
                    end_span_number=span_number))
            LocalDagRunner().run(
                penguin_pipeline_local._create_pipeline(
                    pipeline_name=self._pipeline_name,
                    data_root=self._data_root_span,
                    module_file=self._module_file,
                    accuracy_threshold=0.1,
                    serving_model_dir=self._serving_model_dir,
                    pipeline_root=self._pipeline_root,
                    metadata_path=self._metadata_path,
                    enable_tuning=False,
                    examplegen_input_config=examplegen_input_config,
                    examplegen_range_config=examplegen_range_config,
                    resolver_range_config=resolver_range_config,
                    beam_pipeline_args=[]))

        def span_values(artifacts):
            # Collects the span custom property of each artifact into a set.
            return {
                artifact.custom_properties[
                    utils.SPAN_PROPERTY_NAME].string_value
                for artifact in artifacts
            }

        # Trigger the pipeline for the first span.
        run_pipeline_for_span(1)

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        self.assertPipelineExecution(False)
        transform_execution_type = 'tfx.components.transform.component.Transform'
        trainer_execution_type = 'tfx.components.trainer.component.Trainer'
        expected_execution_count = 10  # 8 components + 2 resolver
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)
            # Verify Transform's input examples artifacts.
            tft_inputs = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_inputs, 1)
            # SpansResolver (controlled by resolver_range_config) returns span 1.
            self.assertEqual(
                '1', tft_inputs[0].custom_properties[
                    utils.SPAN_PROPERTY_NAME].string_value)

        # Trigger the pipeline for the second span.
        run_pipeline_for_span(2)

        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 2, execution_count)
            # Verify Transform's input examples artifacts.
            tft_inputs = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_inputs, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 1 & 2.
            self.assertSetEqual({'1', '2'}, span_values(tft_inputs))
            # Verify Trainer's input examples artifacts.
            self.assertLen(
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type), 2)

        # Trigger the pipeline for the third span.
        run_pipeline_for_span(3)

        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 3, execution_count)
            # Verify Transform's input examples artifacts.
            tft_inputs = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_inputs, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 2 & 3.
            self.assertSetEqual({'2', '3'}, span_values(tft_inputs))
            # Verify Trainer's input examples artifacts.
            self.assertLen(
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type), 2)
Beispiel #19
0
    def testResolveExecProperties(self):
        """Checks span/version resolution in the file-based driver.

        Covers these `resolve_exec_properties` scenarios:
          * the latest span, and then the latest version, must agree across
            all splits, otherwise a ValueError is raised;
          * without a RangeConfig, the latest span and version are selected
            and the split patterns are rewritten to the matching paths;
          * with a static RangeConfig, a 1..2 span range is rejected while a
            1..1 range selects span 1 even though span 2 exists.
        """
        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        def _write_split_data(span, version, split, content):
            # Writes <input_base>/spanXX/versionXX/<split>/data, zero-padded to
            # two digits to match the {SPAN:2}/{VERSION:2} patterns below.
            path = os.path.join(self._input_base_path, 'span%02d' % span,
                                'version%02d' % version, split, 'data')
            io_utils.write_string_file(path, content)

        def _input_config_json():
            # Templated two-split input config, serialized as JSON the same way
            # the exec properties store it.
            return proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN:2}/version{VERSION:2}/split2/*'),
                ]))

        def _static_range_config_json(start_span, end_span):
            # RangeConfig with a static span range, serialized as JSON.
            return proto_utils.proto_to_json(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=start_span,
                        end_span_number=end_span)))

        def _resolve():
            # Resolves self._exec_properties in place.
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        def _resolved_input_config():
            # Parses the (possibly rewritten) input config back into a proto.
            config = example_gen_pb2.Input()
            proto_utils.json_to_proto(
                self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
                config)
            return config

        # Each split contributes a single 9-byte file ('testingXY').
        expected_fingerprint = (
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
            r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )

        # Create exec properties.
        self._exec_properties = {
            standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
            standard_component_specs.INPUT_CONFIG_KEY: _input_config_json(),
            standard_component_specs.RANGE_CONFIG_KEY: None,
        }

        # Test alignment of span numbers: span 1 exists for both splits, but
        # span 2 exists only for s1.
        _write_split_data(1, 1, 'split1', 'testing11')
        _write_split_data(1, 1, 'split2', 'testing12')
        _write_split_data(2, 1, 'split1', 'testing21')

        # Check that an error is raised when the latest span does not match.
        with self.assertRaisesRegex(
                ValueError, 'Latest span should be the same for each split'):
            _resolve()

        # Spans now agree, but s1 has version 2 while s2 only has version 1.
        _write_split_data(2, 1, 'split2', 'testing22')
        _write_split_data(2, 2, 'split1', 'testing21')

        # Check that an error is raised when the span matches but the version
        # does not.
        with self.assertRaisesRegex(
                ValueError,
                'Latest version should be the same for each split'):
            _resolve()

        _write_split_data(2, 2, 'split2', 'testing22')

        # Test that the latest span and version are selected when they align
        # across all splits.
        _resolve()
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 2)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            expected_fingerprint)
        updated_input_config = _resolved_input_config()

        # Check if latest span is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span02/version02/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span02/version02/split2/*"
        }""", updated_input_config)

        # Test driver behavior using RangeConfig with static range. The
        # previous resolve rewrote the split patterns to concrete paths, so
        # the templated input config is restored first.
        self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
            _input_config_json())

        # A static range whose start and end spans differ is rejected.
        self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
            _static_range_config_json(1, 2))
        with self.assertRaisesRegex(
                ValueError, 'For ExampleGen, start and end span numbers'):
            _resolve()

        # A single-span static range selects span 1 although span 2 exists.
        self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
            _static_range_config_json(1, 1))
        _resolve()
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 1)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            expected_fingerprint)
        updated_input_config = _resolved_input_config()
        # Check if correct span inside static range is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)