예제 #1
0
    def testResolveArtifacts(self):
        """Resolves Examples artifacts by span with SpansResolver.

        Five artifacts are created via _createExamples('1')..('5') (presumably
        spans 1..5 -- see that helper) and resolved against a StaticRange, a
        RollingRange, and a RollingRange that also sets start_span_number.
        """
        with metadata.Metadata(connection_config=self._connection_config) as m:
            all_examples = [self._createExamples(str(n)) for n in range(1, 6)]
            _, artifact2, artifact3, artifact4, artifact5 = all_examples

            def _resolved_uris(range_config):
                # Resolves every artifact under `range_config` and returns the
                # URIs in the order the resolver emitted them.
                outcome = spans_resolver.SpansResolver(
                    range_config=range_config).resolve_artifacts(
                        m, {'input': list(all_examples)})
                self.assertIsNotNone(outcome)
                return [a.uri for a in outcome['input']]

            # StaticRange covering spans 2..3.
            static_uris = _resolved_uris(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=2, end_span_number=3)))
            self.assertEqual(static_uris, [artifact3.uri, artifact2.uri])

            # RollingRange over the three latest spans.
            rolling_uris = _resolved_uris(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(num_spans=3)))
            self.assertEqual(rolling_uris,
                             [artifact5.uri, artifact4.uri, artifact3.uri])

            # RollingRange that is not allowed to start before span 4.
            bounded_uris = _resolved_uris(
                range_config_pb2.RangeConfig(
                    rolling_range=range_config_pb2.RollingRange(
                        start_span_number=4, num_spans=3)))
            self.assertEqual(bounded_uris, [artifact5.uri, artifact4.uri])
예제 #2
0
    def testQueryBasedInputProcessor(self):
        """Tests span resolution and pattern rendering for query-based input.

        Covers three cases:
          * a rolling range over span-templated queries, which is expected to
            raise NotImplementedError;
          * a query without span spec and no RangeConfig, resolving to span 0
            with no version and no fingerprint;
          * a static range pinned to span 2, whose @span_yyyymmdd_utc
            placeholder is rendered as '19700103'.
        """
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='s',
                                        pattern='select * from table'),
        ])
        input_config_span = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='select * from table where date=@span_yyyymmdd_utc'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='select * from table2 where date=@span_yyyymmdd_utc')
        ])

        static_range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=2,
                                                      end_span_number=2))
        rolling_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=1))

        # assertRaisesRegexp is a deprecated alias that was removed in
        # Python 3.12; assertRaisesRegex is the supported spelling.
        with self.assertRaisesRegex(
                NotImplementedError,
                'For QueryBasedExampleGen, latest Span is not supported'):
            processor = input_processor.QueryBasedInputProcessor(
                input_config_span.splits, rolling_range_config)
            processor.resolve_span_and_version()

        processor = input_processor.QueryBasedInputProcessor(
            input_config.splits, None)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 0)
        self.assertIsNone(version)
        self.assertIsNone(fp)

        processor = input_processor.QueryBasedInputProcessor(
            input_config_span.splits, static_range_config)
        span, version = processor.resolve_span_and_version()
        fp = processor.get_input_fingerprint(span, version)
        self.assertEqual(span, 2)
        self.assertIsNone(version)
        self.assertIsNone(fp)
        pattern = processor.get_pattern_for_span_version(
            input_config_span.splits[0].pattern, span, version)
        self.assertEqual(pattern, "select * from table where date='19700103'")
예제 #3
0
    def testQueryBasedDriver(self):
        """Runs QueryBasedDriver with a static range and checks its outputs."""
        span_query_splits = [
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='select * from table where date=@span_yyyymmdd_utc'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='select * from table2 where date=@span_yyyymmdd_utc'),
        ]
        span_two_only = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=2, end_span_number=2))

        # Exec properties carry both configs in JSON-serialized form.
        exec_properties = {
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(
                    example_gen_pb2.Input(splits=span_query_splits)),
            standard_component_specs.RANGE_CONFIG_KEY:
                proto_utils.proto_to_json(span_two_only),
        }

        # A single Examples output artifact for the driver to populate.
        example = standard_artifacts.Examples()
        example.uri = 'my_uri'
        output_dict = {standard_component_specs.EXAMPLES_KEY: [example]}

        query_driver = driver.QueryBasedDriver(self._mock_metadata)
        result = query_driver.run(
            portable_data_types.ExecutionInfo(
                output_dict=output_dict, exec_properties=exec_properties))

        # The resolved span/version/fingerprint are read back from the same
        # exec_properties dict that was handed to the driver.
        self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        for unset_key in (utils.VERSION_PROPERTY_NAME,
                          utils.FINGERPRINT_PROPERTY_NAME):
            self.assertIsNone(exec_properties[unset_key])

        rendered_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            rendered_config)
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "select * from table where date='19700103'"
        }
        splits {
          name: "s2"
          pattern: "select * from table2 where date='19700103'"
        }""", rendered_config)

        produced = result.output_artifacts[
            standard_component_specs.EXAMPLES_KEY].artifacts
        self.assertLen(produced, 1)
        output_example = produced[0]
        self.assertEqual(output_example.uri, example.uri)
        self.assertEqual(
            output_example.custom_properties[
                utils.SPAN_PROPERTY_NAME].int_value, 2)
예제 #4
0
  def __init__(self,
               splits: Iterable[example_gen_pb2.Input.Split],
               range_config: Optional[range_config_pb2.RangeConfig] = None):
    """Initialize InputProcessor.

    Every split must use the same combination of span, date and version specs
    in its pattern. If the patterns contain a span or date spec and no
    range_config is given, a RollingRange with num_spans=1 is used.

    Args:
      splits: An iterable collection of example_gen_pb2.Input.Split objects.
        It is copied into a list, so one-shot iterators are accepted.
      range_config: An instance of range_config_pb2.RangeConfig, defines the
        rules for span resolving.

    Raises:
      ValueError: If the splits disagree on their span/date/version specs; if
        range_config is given but the patterns have neither a span nor a date
        spec; if a StaticRange has different start and end span numbers; if a
        RollingRange has num_spans != 1 or a positive start_span_number; or if
        range_config sets neither static_range nor rolling_range.
    """
    # `splits` is iterated below and then stored for later use. Materialize it
    # so a generator is not exhausted by the validation loop, which would
    # otherwise leave self._splits empty.
    splits = list(splits)

    self._is_match_span = None
    self._is_match_date = None
    self._is_match_version = None
    for split in splits:
      is_match_span, is_match_date, is_match_version = utils.verify_split_pattern_specs(
          split)
      if self._is_match_span is None:
        # The first split defines the spec setup all others must match.
        self._is_match_span = is_match_span
        self._is_match_date = is_match_date
        self._is_match_version = is_match_version
      elif (self._is_match_span != is_match_span or
            self._is_match_date != is_match_date or
            self._is_match_version != is_match_version):
        raise ValueError('Spec setup should the same for all splits: %s.' %
                         split.pattern)

    self._splits = splits

    # With a span/date spec but no explicit range, fall back to a rolling
    # range of a single span.
    if (self._is_match_span or self._is_match_date) and not range_config:
      range_config = range_config_pb2.RangeConfig(
          rolling_range=range_config_pb2.RollingRange(num_spans=1))
    if not self._is_match_span and not self._is_match_date and range_config:
      raise ValueError(
          'Span or Date spec should be specified in split pattern if RangeConfig is specified.'
      )

    if range_config:
      if range_config.HasField('static_range'):
        # A StaticRange must pin exactly one span.
        start_span_number = range_config.static_range.start_span_number
        end_span_number = range_config.static_range.end_span_number
        if start_span_number != end_span_number:
          raise ValueError(
              'For ExampleGen, start and end span numbers for RangeConfig.StaticRange must be equal.'
          )
      elif range_config.HasField('rolling_range'):
        if range_config.rolling_range.num_spans != 1:
          raise ValueError(
              'ExampleGen only support single span for RangeConfig.RollingRange.'
          )
        if range_config.rolling_range.start_span_number > 0:
          raise ValueError(
              'RangeConfig.rolling_range.start_span_number is not supported.')
      else:
        raise ValueError('Only static_range and rolling_range are supported.')

    self._range_config = range_config
예제 #5
0
 def testConstructWithRangeConfig(self):
     """Checks BigQueryExampleGen's output type and its stored RangeConfig."""
     span_two = range_config_pb2.RangeConfig(
         static_range=range_config_pb2.StaticRange(
             start_span_number=2, end_span_number=2))
     # @span_yyyymmdd_utc will be replaced with '19700103' in the query, and
     # span 2 will be recorded in the output Examples artifact.
     example_gen = component.BigQueryExampleGen(
         query='select * from table where date=@span_yyyymmdd_utc',
         range_config=span_two)

     examples_channel = example_gen.outputs[
         standard_component_specs.EXAMPLES_KEY]
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)

     # The RangeConfig is stored as JSON in exec_properties; it must survive
     # a round trip unchanged.
     round_tripped = range_config_pb2.RangeConfig()
     proto_utils.json_to_proto(
         example_gen.exec_properties[
             standard_component_specs.RANGE_CONFIG_KEY],
         round_tripped)
     self.assertEqual(span_two, round_tripped)
예제 #6
0
    def testStrategy(self):
        """Resolves Examples artifacts by span with SpanRangeStrategy.

        Covers a StaticRange, a RollingRange, and a RollingRange with
        start_span_number, all against artifacts created by _createExamples
        with ids 1..5.
        """
        artifacts = [self._createExamples(span) for span in (1, 2, 3, 4, 5)]
        _, artifact2, artifact3, artifact4, artifact5 = artifacts

        def _resolve(range_config):
            # Runs the strategy over all artifacts and returns the resolved
            # 'input' list after checking that a result was produced.
            strategy = span_range_strategy.SpanRangeStrategy(
                range_config=range_config)
            resolved = strategy.resolve_artifacts(
                self._store, {'input': list(artifacts)})
            self.assertIsNotNone(resolved)
            return resolved['input']

        # StaticRange over spans 2..3; compared as a set, so order is free.
        static_result = _resolve(
            range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=2, end_span_number=3)))
        self.assertEqual({a.uri for a in static_result},
                         {artifact3.uri, artifact2.uri})

        # RollingRange over the three latest spans, newest first.
        rolling_result = _resolve(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(num_spans=3)))
        self.assertEqual([a.uri for a in rolling_result],
                         [artifact5.uri, artifact4.uri, artifact3.uri])

        # RollingRange that is not allowed to start before span 4.
        bounded_result = _resolve(
            range_config_pb2.RangeConfig(
                rolling_range=range_config_pb2.RollingRange(
                    start_span_number=4, num_spans=3)))
        self.assertEqual([a.uri for a in bounded_result],
                         [artifact5.uri, artifact4.uri])
예제 #7
0
  def testRangeConfigWithNonexistentSpan(self):
    """A StaticRange naming a span with no data must raise ValueError."""
    # Only span 1 exists on disk; the range below asks for span 2.
    io_utils.write_string_file(
        os.path.join(self._input_base_path, 'span01', 'split1', 'data'),
        'testing11')

    missing_span_config = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=2, end_span_number=2))
    split_specs = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
    ]

    with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, split_specs, range_config=missing_span_config)
예제 #8
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Resolves the span, version and fingerprint of the input under
        INPUT_BASE_KEY. It writes them into `exec_properties`, together with
        the input config whose split patterns have been resolved, and returns
        the same dict.

        Raises:
          ValueError: If a StaticRange range config has different start and
            end span numbers.
        """
        del pipeline_info, component_info

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            input_config)

        base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY]
        logging.debug('Processing input %s.', base_uri)

        range_config = None
        serialized_range = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)

            # ExampleGen processes one span per run, so a StaticRange has to
            # name exactly one span.
            if range_config.HasField('static_range'):
                bounds = range_config.static_range
                first, last = bounds.start_span_number, bounds.end_span_number
                if first != last:
                    raise ValueError(
                        'Start and end span numbers for RangeConfig.static_range must '
                        'be equal: (%s, %s)' % (first, last))

        # This call rewrites input_config.splits[*].pattern in place.
        fingerprint, span, version = (
            utils.calculate_splits_fingerprint_span_and_version(
                base_uri, input_config.splits, range_config))

        serialized_input = proto_utils.proto_to_json(input_config)
        exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = serialized_input
        exec_properties[utils.SPAN_PROPERTY_NAME] = span
        exec_properties[utils.VERSION_PROPERTY_NAME] = version
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
        return exec_properties
예제 #9
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        It obtains an input processor from self.get_input_processor and uses it
        to resolve the span and version and compute the fingerprint. Each split
        pattern is pinned to that span/version, and the results are written
        back into `exec_properties`, which is also returned.
        """
        del pipeline_info, component_info

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            input_config)

        base_uri = exec_properties.get(standard_component_specs.INPUT_BASE_KEY)
        logging.debug('Processing input %s.', base_uri)

        range_config = None
        serialized_range = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)

        processor = self.get_input_processor(splits=input_config.splits,
                                             range_config=range_config,
                                             input_base_uri=base_uri)
        span, version = processor.resolve_span_and_version()
        fingerprint = processor.get_input_fingerprint(span, version)

        # Replace each split's templated pattern with the concrete one.
        for split in input_config.splits:
            templated = split.pattern
            split.pattern = processor.get_pattern_for_span_version(
                templated, span, version)

        serialized_input = proto_utils.proto_to_json(input_config)
        exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = serialized_input
        exec_properties[utils.SPAN_PROPERTY_NAME] = span
        exec_properties[utils.VERSION_PROPERTY_NAME] = version
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
        return exec_properties
예제 #10
0
    def testResolve(self):
        """SpansResolver.resolve returns only the artifact inside the range."""
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = m.register_pipeline_contexts_if_not_exists(
                self._pipeline_info)

            def _examples_with_span(uri, span):
                # Builds an Examples artifact whose span custom property is
                # stored as a string.
                artifact = standard_artifacts.Examples()
                artifact.uri = uri
                artifact.set_string_custom_property(utils.SPAN_PROPERTY_NAME,
                                                    span)
                return artifact

            artifact_one = _examples_with_span('uri_one', '1')
            m.publish_artifacts([artifact_one])
            artifact_two = _examples_with_span('uri_two', '2')

            m.register_execution(exec_properties={},
                                 pipeline_info=self._pipeline_info,
                                 component_info=self._component_info,
                                 contexts=contexts)
            m.publish_execution(
                component_info=self._component_info,
                output_artifacts={'key': [artifact_one, artifact_two]})

            span_one_only = range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=1, end_span_number=1))
            resolver = spans_resolver.SpansResolver(range_config=span_one_only)
            producer_id = self._component_info.component_id
            input_channel = types.Channel(type=artifact_one.type,
                                          producer_component_id=producer_id,
                                          output_key='key')
            resolve_result = resolver.resolve(
                pipeline_info=self._pipeline_info,
                metadata_handler=m,
                source_channels={'input': input_channel})

            self.assertTrue(resolve_result.has_complete_result)
            resolved_uris = [
                artifact.uri
                for artifact in resolve_result.per_key_resolve_result['input']
            ]
            self.assertEqual(resolved_uris, [artifact_one.uri])
            self.assertTrue(resolve_result.per_key_resolve_state['input'])
예제 #11
0
    def __init__(self,
                 input_base_uri: Text,
                 splits: Iterable[example_gen_pb2.Input.Split],
                 range_config: Optional[range_config_pb2.RangeConfig] = None):
        """Initialize FileBasedInputProcessor.

        Args:
          input_base_uri: The base path from which files will be searched.
          splits: An iterable collection of example_gen_pb2.Input.Split objects.
            It is copied into a list, so one-shot iterators are accepted.
          range_config: An instance of range_config_pb2.RangeConfig, defines the
            rules for span resolving.

        Raises:
          ValueError: If the splits do not share the same span/date/version
            spec setup, or if range_config is given while no split pattern
            contains a span or date spec.
        """
        # `splits` is consumed both by the base-class initializer and by the
        # validation loop below. Materialize it once; otherwise a one-shot
        # iterator would be empty on the second pass, and the _is_match_*
        # flags would be reset to None.
        splits = list(splits)
        super(FileBasedInputProcessor,
              self).__init__(splits=splits, range_config=range_config)

        self._is_match_span = None
        self._is_match_date = None
        self._is_match_version = None
        for split in splits:
            is_match_span, is_match_date, is_match_version = utils.verify_split_pattern_specs(
                split)
            if self._is_match_span is None:
                # The first split defines the spec setup all others must match.
                self._is_match_span = is_match_span
                self._is_match_date = is_match_date
                self._is_match_version = is_match_version
            elif (self._is_match_span != is_match_span
                  or self._is_match_date != is_match_date
                  or self._is_match_version != is_match_version):
                raise ValueError(
                    'Spec setup should the same for all splits: %s.' %
                    split.pattern)

        # The original code set a RollingRange default for `range_config` here,
        # but that local is never read afterwards. The default only applied
        # when a span/date spec exists, and this check requires that none does,
        # so the assignment was dead and has been removed.
        if not self._is_match_span and not self._is_match_date and range_config:
            raise ValueError(
                'Span or Date spec should be specified in split pattern if RangeConfig is specified.'
            )

        self._input_base_uri = input_base_uri
        self._fingerprint = None
예제 #12
0
  def testSpanAlignWithRangeConfig(self):
    """A StaticRange selects its span even when a newer span exists."""
    for span_dir, contents in (('span01', 'testing11'),
                               ('span02', 'testing21')):
      io_utils.write_string_file(
          os.path.join(self._input_base_path, span_dir, 'split1', 'data'),
          contents)

    # Only span 1 is requested through the static range.
    single_span = range_config_pb2.RangeConfig(
        static_range=range_config_pb2.StaticRange(
            start_span_number=1, end_span_number=1))
    split_specs = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
    ]

    _, span, version = utils.calculate_splits_fingerprint_span_and_version(
        self._input_base_path, split_specs, single_span)
    self.assertEqual(span, 1)
    self.assertIsNone(version)
    # The split pattern is rewritten in place to point at the chosen span.
    self.assertEqual(split_specs[0].pattern, 'span01/split1/*')
예제 #13
0
    def testRangeConfigWithDateSpec(self):
        """A date-based StaticRange resolves a {YYYY}{MM}{DD} split pattern."""
        io_utils.write_string_file(
            os.path.join(self._input_base_path, '19700102', 'split1', 'data'),
            'testing11')

        # Both ends of the range map to 1970-01-02, a single-day range.
        first_day = utils.date_to_span_number(1970, 1, 2)
        last_day = utils.date_to_span_number(1970, 1, 2)
        date_range = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=first_day, end_span_number=last_day))

        split_specs = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='{YYYY}{MM}{DD}/split1/*')
        ]
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, split_specs, range_config=date_range)

        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(split_specs[0].pattern, '19700102/split1/*')
예제 #14
0
    def testResolveExecProperties(self):
        """Tests span/version resolution in the file-based ExampleGen driver.

        Checks that the driver:
          * raises when splits disagree on the latest span, and then on the
            latest version;
          * picks the latest aligned span/version when no RangeConfig is set;
          * rejects a StaticRange that covers more than one span;
          * honours a single-span StaticRange.
        """
        # Create input dir.
        self._input_base_path = os.path.join(self._test_dir, 'input_base')
        fileio.makedirs(self._input_base_path)

        # Create exec properties.
        self._exec_properties = {
            standard_component_specs.INPUT_BASE_KEY:
            self._input_base_path,
            standard_component_specs.INPUT_CONFIG_KEY:
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN:2}/version{VERSION:2}/split2/*')
                ])),
            standard_component_specs.RANGE_CONFIG_KEY:
            None,
        }

        # Test align of span number.
        span1_v1_split1 = os.path.join(self._input_base_path, 'span01',
                                       'version01', 'split1', 'data')
        io_utils.write_string_file(span1_v1_split1, 'testing11')
        span1_v1_split2 = os.path.join(self._input_base_path, 'span01',
                                       'version01', 'split2', 'data')
        io_utils.write_string_file(span1_v1_split2, 'testing12')
        span2_v1_split1 = os.path.join(self._input_base_path, 'span02',
                                       'version01', 'split1', 'data')
        io_utils.write_string_file(span2_v1_split1, 'testing21')

        # Check that an error is raised when the span does not match.
        # assertRaisesRegexp is a deprecated alias that was removed in Python
        # 3.12; assertRaisesRegex is the supported spelling.
        with self.assertRaisesRegex(
                ValueError, 'Latest span should be the same for each split'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        span2_v1_split2 = os.path.join(self._input_base_path, 'span02',
                                       'version01', 'split2', 'data')
        io_utils.write_string_file(span2_v1_split2, 'testing22')
        span2_v2_split1 = os.path.join(self._input_base_path, 'span02',
                                       'version02', 'split1', 'data')
        io_utils.write_string_file(span2_v2_split1, 'testing21')

        # Check that an error is raised when the span matches but the version
        # does not.
        with self.assertRaisesRegex(
                ValueError,
                'Latest version should be the same for each split'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        span2_v2_split2 = os.path.join(self._input_base_path, 'span02',
                                       'version02', 'split2', 'data')
        io_utils.write_string_file(span2_v2_split2, 'testing22')

        # Test that the latest span and version are selected when span and
        # version align for each split.
        self._file_based_driver.resolve_exec_properties(
            self._exec_properties, None, None)
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 2)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 2)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            updated_input_config)

        # Check that the latest span is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span02/version02/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span02/version02/split2/*"
        }""", updated_input_config)

        # Test driver behavior using RangeConfig with a static range.
        self._exec_properties[
            standard_component_specs.
            INPUT_CONFIG_KEY] = proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='s1',
                        pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
                    example_gen_pb2.Input.Split(
                        name='s2',
                        pattern='span{SPAN:2}/version{VERSION:2}/split2/*'),
                ]))

        # A StaticRange covering spans 1..2 is rejected.
        self._exec_properties[
            standard_component_specs.
            RANGE_CONFIG_KEY] = proto_utils.proto_to_json(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=1, end_span_number=2)))
        with self.assertRaisesRegex(
                ValueError, 'For ExampleGen, start and end span numbers'):
            self._file_based_driver.resolve_exec_properties(
                self._exec_properties, None, None)

        # A StaticRange pinned to span 1 selects span 1 rather than the
        # latest span.
        self._exec_properties[
            standard_component_specs.
            RANGE_CONFIG_KEY] = proto_utils.proto_to_json(
                range_config_pb2.RangeConfig(
                    static_range=range_config_pb2.StaticRange(
                        start_span_number=1, end_span_number=1)))
        self._file_based_driver.resolve_exec_properties(
            self._exec_properties, None, None)
        self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1)
        self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 1)
        self.assertRegex(
            self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
            r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        )
        updated_input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            updated_input_config)
        # Check that the correct span inside the static range is selected.
        self.assertProtoEquals(
            """
        splits {
          name: "s1"
          pattern: "span01/version01/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span01/version01/split2/*"
        }""", updated_input_config)
예제 #15
0
파일: driver.py 프로젝트: jay90099/tfx
def _run_driver(executor_input: pipeline_spec_pb2.ExecutorInput) -> None:
    """Runs the driver, writing its output as an ExecutorOutput proto.

    The main goal of this driver is to calculate the span and fingerprint of
    input data, allowing for the executor invocation to be skipped if the
    ExampleGen component has been previously run on the same data with the
    same configuration. This span and fingerprint are added as new custom
    execution properties to an ExecutorOutput proto and written to a GCS path.
    The CAIP pipelines system reads this file and updates MLMD with the new
    execution properties.

    Args:
      executor_input: pipeline_spec_pb2.ExecutorInput that contains TFX
        artifacts and exec_properties information.

    Raises:
      ValueError: If the ExampleGen outputs contain no examples artifact.
    """
    exec_properties = kubeflow_v2_entrypoint_utils.parse_execution_properties(
        executor_input.inputs.parameters)
    # id -> name mapping shared by parse_raw_artifact_dict (here) and
    # to_runtime_artifact (below) when converting artifacts.
    name_from_id = {}
    outputs_dict = kubeflow_v2_entrypoint_utils.parse_raw_artifact_dict(
        executor_input.outputs.artifacts, name_from_id)
    # A path at which an ExecutorOutput message will be written with updated
    # execution properties and output artifacts. The CAIP Pipelines service
    # will update the task's properties and artifacts prior to running the
    # executor.
    output_metadata_uri = executor_input.outputs.output_file

    logging.set_verbosity(logging.INFO)
    logging.info('exec_properties = %s\noutput_metadata_uri = %s',
                 exec_properties, output_metadata_uri)

    input_base_uri = exec_properties.get(
        standard_component_specs.INPUT_BASE_KEY)

    input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
        input_config)

    # The range config is optional; when absent, None is passed on to the
    # input processor.
    range_config = None
    range_config_entry = exec_properties.get(
        standard_component_specs.RANGE_CONFIG_KEY)
    if range_config_entry:
        range_config = range_config_pb2.RangeConfig()
        proto_utils.json_to_proto(range_config_entry, range_config)

    processor = input_processor.FileBasedInputProcessor(
        input_base_uri, input_config.splits, range_config)
    span, version = processor.resolve_span_and_version()
    fingerprint = processor.get_input_fingerprint(span, version)

    logging.info('Calculated span: %s', span)
    logging.info('Calculated fingerprint: %s', fingerprint)

    exec_properties[utils.SPAN_PROPERTY_NAME] = span
    exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
    exec_properties[utils.VERSION_PROPERTY_NAME] = version

    # Rewrite each split's pattern for the resolved span/version.
    for split in input_config.splits:
        split.pattern = processor.get_pattern_for_span_version(
            split.pattern, span, version)
    # Serialize once so exec_properties and the ExecutorOutput carry the same
    # JSON text.
    input_config_json = proto_utils.proto_to_json(input_config)
    exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
        input_config_json)

    if standard_component_specs.EXAMPLES_KEY not in outputs_dict:
        raise ValueError(
            'Example artifact was missing in the ExampleGen outputs.')
    example_artifact = artifact_utils.get_single_instance(
        outputs_dict[standard_component_specs.EXAMPLES_KEY])

    driver.update_output_artifact(
        exec_properties=exec_properties,
        output_artifact=example_artifact.mlmd_artifact)

    # Build the ExecutorOutput that is handed back to the pipelines service.
    output_metadata = pipeline_spec_pb2.ExecutorOutput()
    output_metadata.parameters[utils.SPAN_PROPERTY_NAME].int_value = span
    output_metadata.parameters[
        utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint
    # The resolved version may be None; only report it when it is set.
    if version is not None:
        output_metadata.parameters[
            utils.VERSION_PROPERTY_NAME].int_value = version
    output_metadata.parameters[
        standard_component_specs.INPUT_CONFIG_KEY].string_value = (
            input_config_json)
    output_metadata.artifacts[
        standard_component_specs.EXAMPLES_KEY].artifacts.add().CopyFrom(
            kubeflow_v2_entrypoint_utils.to_runtime_artifact(
                example_artifact, name_from_id))

    output_dir = os.path.dirname(output_metadata_uri)
    if output_dir:
        # A bare file name has no parent directory to create.
        fileio.makedirs(output_dir)
    with fileio.open(output_metadata_uri, 'wb') as f:
        # The file is opened in binary mode, so hand it bytes instead of
        # relying on the filesystem implementation to encode a str.
        f.write(
            json_format.MessageToJson(output_metadata,
                                      sort_keys=True).encode('utf-8'))
    def testPenguinPipelineLocalWithRollingWindow(self):
        """Runs the local penguin pipeline over spans 1-3, one span per run.

        Each run restricts ExampleGen to a single span via a StaticRange, while
        the resolver uses RollingRange(num_spans=2). The test checks that the
        execution count grows by a fixed amount per run and that Transform and
        Trainer consume the expected spans ({1}, then {1, 2}, then {2, 3}).
        """
        examplegen_input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='test', pattern='day{SPAN}/*'),
        ])
        resolver_range_config = range_config_pb2.RangeConfig(
            rolling_range=range_config_pb2.RollingRange(num_spans=2))

        def run_pipeline(span):
            """Runs the pipeline with ExampleGen restricted to `span` only."""
            examplegen_range_config = range_config_pb2.RangeConfig(
                static_range=range_config_pb2.StaticRange(
                    start_span_number=span, end_span_number=span))
            LocalDagRunner().run(
                penguin_pipeline_local._create_pipeline(
                    pipeline_name=self._pipeline_name,
                    data_root=self._data_root_span,
                    module_file=self._module_file,
                    accuracy_threshold=0.1,
                    serving_model_dir=self._serving_model_dir,
                    pipeline_root=self._pipeline_root,
                    metadata_path=self._metadata_path,
                    enable_tuning=False,
                    examplegen_input_config=examplegen_input_config,
                    examplegen_range_config=examplegen_range_config,
                    resolver_range_config=resolver_range_config,
                    beam_pipeline_args=[]))

        def get_spans(artifacts):
            """Returns the set of SPAN custom-property values of `artifacts`."""
            return {
                a.custom_properties[utils.SPAN_PROPERTY_NAME].string_value
                for a in artifacts
            }

        # Trigger the pipeline for the first span.
        run_pipeline(1)

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        self.assertPipelineExecution(False)
        transform_execution_type = 'tfx.components.transform.component.Transform'
        trainer_execution_type = 'tfx.components.trainer.component.Trainer'
        expected_execution_count = 10  # 8 components + 2 resolver
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 1)
            # SpansResolver (controlled by resolver_range_config) returns span 1.
            self.assertEqual({'1'}, get_spans(tft_input_examples_artifacts))

        # Trigger the pipeline for the second span.
        run_pipeline(2)

        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 2, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 1 & 2.
            self.assertSetEqual({'1', '2'},
                                get_spans(tft_input_examples_artifacts))
            # Verify Trainer's input examples artifacts.
            self.assertLen(
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type), 2)

        # Trigger the pipeline for the third span.
        run_pipeline(3)

        with metadata.Metadata(metadata_config) as m:
            execution_count = len(m.store.get_executions())
            self.assertEqual(expected_execution_count * 3, execution_count)
            # Verify Transform's input examples artifacts.
            tft_input_examples_artifacts = self._get_input_examples_artifacts(
                m.store, transform_execution_type)
            self.assertLen(tft_input_examples_artifacts, 2)
            # SpansResolver (controlled by resolver_range_config) returns span 2 & 3.
            self.assertSetEqual({'2', '3'},
                                get_spans(tft_input_examples_artifacts))
            # Verify Trainer's input examples artifacts.
            self.assertLen(
                self._get_input_examples_artifacts(m.store,
                                                   trainer_execution_type), 2)
예제 #17
0
파일: driver.py 프로젝트: numerology/tfx
def _run_driver(exec_properties: Dict[str, Any],
                outputs_dict: Dict[str, List[artifact.Artifact]],
                output_metadata_uri: str,
                name_from_id: Optional[Dict[int, str]] = None) -> None:
  """Runs the driver, writing its output as an ExecutorOutput proto.

  The main goal of this driver is to calculate the span and fingerprint of input
  data, allowing for the executor invocation to be skipped if the ExampleGen
  component has been previously run on the same data with the same
  configuration. This span and fingerprint are added as new custom execution
  properties to an ExecutorOutput proto and written to a GCS path. The CAIP
  pipelines system reads this file and updates MLMD with the new execution
  properties.

  Args:
    exec_properties: ExampleGen execution properties; updated in place. Keys
      read:
      standard_component_specs.INPUT_BASE_KEY: A path from which files will be
        read and their span/fingerprint calculated.
      standard_component_specs.INPUT_CONFIG_KEY (required): A json-serialized
        tfx.proto.example_gen_pb2.Input proto message.
        See https://www.tensorflow.org/tfx/guide/examplegen for more details.
      standard_component_specs.RANGE_CONFIG_KEY (optional): A json-serialized
        tfx.proto.range_config_pb2.RangeConfig proto message.
      On return the span, version and fingerprint properties are set, and the
      input config's split patterns are resolved for that span/version.
    outputs_dict: The mapping of the output artifacts; must contain exactly
      one artifact under standard_component_specs.EXAMPLES_KEY.
    output_metadata_uri: A path at which an ExecutorOutput message will be
      written with updated execution properties and output artifacts. The CAIP
      Pipelines service will update the task's properties and artifacts prior to
      running the executor.
    name_from_id: Optional. Mapping from the converted int-typed id to str-typed
      runtime artifact name, which should be unique.

  Raises:
    ValueError: If `outputs_dict` has no examples artifact.
  """
  if name_from_id is None:
    name_from_id = {}

  logging.set_verbosity(logging.INFO)
  logging.info('exec_properties = %s\noutput_metadata_uri = %s',
               exec_properties, output_metadata_uri)

  input_base_uri = exec_properties.get(standard_component_specs.INPUT_BASE_KEY)

  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.INPUT_CONFIG_KEY], input_config)

  # The range config is optional; when absent, None is passed on to the input
  # processor.
  range_config = None
  range_config_entry = exec_properties.get(
      standard_component_specs.RANGE_CONFIG_KEY)
  if range_config_entry:
    range_config = range_config_pb2.RangeConfig()
    proto_utils.json_to_proto(range_config_entry, range_config)

  processor = input_processor.FileBasedInputProcessor(input_base_uri,
                                                      input_config.splits,
                                                      range_config)
  span, version = processor.resolve_span_and_version()
  fingerprint = processor.get_input_fingerprint(span, version)

  logging.info('Calculated span: %s', span)
  logging.info('Calculated fingerprint: %s', fingerprint)

  exec_properties[utils.SPAN_PROPERTY_NAME] = span
  exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
  exec_properties[utils.VERSION_PROPERTY_NAME] = version

  # Rewrite each split's pattern for the resolved span/version.
  for split in input_config.splits:
    split.pattern = processor.get_pattern_for_span_version(
        split.pattern, span, version)
  # Serialize once with the same helper for both consumers so exec_properties
  # and the ExecutorOutput always carry identical JSON.
  input_config_json = proto_utils.proto_to_json(input_config)
  exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
      input_config_json)

  if standard_component_specs.EXAMPLES_KEY not in outputs_dict:
    raise ValueError('Example artifact was missing in the ExampleGen outputs.')
  example_artifact = artifact_utils.get_single_instance(
      outputs_dict[standard_component_specs.EXAMPLES_KEY])

  driver.update_output_artifact(
      exec_properties=exec_properties,
      output_artifact=example_artifact.mlmd_artifact)

  # Build the ExecutorOutput that is handed back to the pipelines service.
  output_metadata = pipeline_pb2.ExecutorOutput()
  output_metadata.parameters[
      utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint
  output_metadata.parameters[utils.SPAN_PROPERTY_NAME].string_value = str(span)
  output_metadata.parameters[
      standard_component_specs.INPUT_CONFIG_KEY].string_value = (
          input_config_json)
  output_metadata.artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts.add().CopyFrom(
          kubeflow_v2_entrypoint_utils.to_runtime_artifact(
              example_artifact, name_from_id))

  output_dir = os.path.dirname(output_metadata_uri)
  if output_dir:
    # A bare file name has no parent directory to create.
    fileio.makedirs(output_dir)
  with fileio.open(output_metadata_uri, 'wb') as f:
    # The file is opened in binary mode, so hand it bytes instead of relying
    # on the filesystem implementation to encode a str.
    f.write(
        json_format.MessageToJson(output_metadata,
                                  sort_keys=True).encode('utf-8'))