コード例 #1
0
    def testVersionWidth(self):
        """Checks that the width modifier on the {VERSION} spec is honored.

        Only a single-digit version directory ('ver1') is written, so the
        pattern using '{VERSION:2}' is expected to raise ValueError, while
        '{VERSION:1}' is expected to resolve to span 1 and version 1.
        """
        data_path = os.path.join(self._input_base_path, 'span1', 'ver1',
                                 'split1', 'data')
        io_utils.write_string_file(data_path, 'testing')

        splits = [
            example_gen_pb2.Input.Split(
                name='s1', pattern='span{SPAN}/ver{VERSION:2}/split1/*')
        ]

        # TODO(jjma): find a better way of describing this error to user.
        # assertRaisesRegex is used because the assertRaisesRegexp alias is
        # deprecated and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Glob pattern does not match regex pattern'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits)

        splits = [
            example_gen_pb2.Input.Split(
                name='s1', pattern='span{SPAN}/ver{VERSION:1}/split1/*')
        ]

        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits)
        self.assertEqual(span, 1)
        self.assertEqual(version, 1)
コード例 #2
0
    def testRangeConfigSpanWidthPresence(self):
        """Checks RangeConfig.static_range with and without a span width.

        The only data on disk lives under the zero-padded directory 'span01'.
        With a static range selecting span 1, the pattern 'span{SPAN}' is
        expected to fail to match, while 'span{SPAN:2}' is expected to be
        rewritten to 'span01/split1/*'.
        """
        span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                                    'data')
        io_utils.write_string_file(span1_split1, 'testing11')

        range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=1))
        splits1 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN}/split1/*')
        ]

        # RangeConfig cannot find zero padding span without width modifier.
        # assertRaisesRegex is used because the assertRaisesRegexp alias is
        # deprecated and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'Cannot find matching for split'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits1, range_config=range_config)

        splits2 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN:2}/split1/*')
        ]

        # With width modifier in span spec, RangeConfig.static_range makes
        # correct zero-padded substitution.
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits2, range_config=range_config)
        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(splits2[0].pattern, 'span01/split1/*')
コード例 #3
0
    def testCalculateSplitsFingerprintSpanAndVersionWithDate(self):
        """Checks span/version alignment across splits when using date specs.

        Files are added step by step so that, in turn: the latest date differs
        between the two splits (error expected), the dates align but the
        latest version differs (error expected), and finally both align and
        the latest date and version are substituted into each split pattern.
        """

        def write_data(date_dir, version_dir, split_dir, contents):
            # Writes <input_base>/<date>/<version>/<split>/data.
            io_utils.write_string_file(
                os.path.join(self._input_base_path, date_dir, version_dir,
                             split_dir, 'data'), contents)

        def new_splits():
            # A fresh list per call: the resolved patterns are written back
            # into the Split messages, as the final assertions demonstrate.
            return [
                example_gen_pb2.Input.Split(
                    name='s1', pattern='{YYYY}{MM}{DD}/ver{VERSION}/split1/*'),
                example_gen_pb2.Input.Split(
                    name='s2', pattern='{YYYY}{MM}{DD}/ver{VERSION}/split2/*')
            ]

        # Test align of span and version numbers.
        write_data('19700102', 'ver01', 'split1', 'testing11')
        write_data('19700102', 'ver01', 'split2', 'testing12')
        write_data('19700103', 'ver01', 'split1', 'testing21')

        # Test if error raised when date does not align.
        # assertRaisesRegex is used because the assertRaisesRegexp alias is
        # deprecated and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Latest span should be the same for each split'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, new_splits())

        write_data('19700103', 'ver01', 'split2', 'testing22')
        write_data('19700103', 'ver02', 'split1', 'testing21')

        # Test if error raised when date aligns but version does not.
        with self.assertRaisesRegex(
                ValueError,
                'Latest version should be the same for each split'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, new_splits())

        write_data('19700103', 'ver02', 'split2', 'testing22')

        # Test if latest span and version is selected when aligned for each split.
        splits = new_splits()
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits)
        self.assertEqual(span, 2)
        self.assertEqual(version, 2)
        self.assertEqual(splits[0].pattern, '19700103/ver02/split1/*')
        self.assertEqual(splits[1].pattern, '19700103/ver02/split2/*')
コード例 #4
0
 def testDateSpecPartiallyMissing(self):
   """Expects a ValueError when only some of the date specs are present.

   The pattern supplies {YYYY} and {MM} but leaves out {DD}.
   """
   partial_date_split = example_gen_pb2.Input.Split(
       name='s1', pattern='{YYYY}-{MM}/split1/*')
   base_path = self._input_base_path

   with self.assertRaisesRegex(ValueError, 'Exactly one of each date spec'):
     utils.calculate_splits_fingerprint_span_and_version(
         base_path, [partial_date_split])
コード例 #5
0
 def testSpanNoMatching(self):
   """Expects a ValueError when no input file matches the span patterns.

   No files are written by this test before resolving the two splits.
   """
   split_specs = (
       ('s1', 'span{SPAN}/split1/*'),
       ('s2', 'span{SPAN}/split2/*'),
   )
   splits = [
       example_gen_pb2.Input.Split(name=split_name, pattern=split_pattern)
       for split_name, split_pattern in split_specs
   ]

   with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
     utils.calculate_splits_fingerprint_span_and_version(
         self._input_base_path, splits)
コード例 #6
0
  def testSpanInvalidWidth(self):
    """Expects a ValueError for the non-numeric width modifier in '{SPAN:x}'."""
    bad_width_split = example_gen_pb2.Input.Split(
        name='s1', pattern='{SPAN:x}/split1/*')
    expected_error = 'Width modifier in span spec is not a positive integer'

    with self.assertRaisesRegex(ValueError, expected_error):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, [bad_width_split])
コード例 #7
0
  def testSpanWrongFormat(self):
    """Expects a ValueError when the span part of the path is not a number.

    The only file on disk sits under 'spanx', which the 'span{SPAN}' pattern
    is expected to reach but not be able to turn into a span number.
    """
    base_path = self._input_base_path
    io_utils.write_string_file(
        os.path.join(base_path, 'spanx', 'split1', 'data'),
        'testing_wrong_span')

    span_split = example_gen_pb2.Input.Split(
        name='s1', pattern='span{SPAN}/split1/*')

    with self.assertRaisesRegex(ValueError, 'Cannot find span number'):
      utils.calculate_splits_fingerprint_span_and_version(
          base_path, [span_split])
コード例 #8
0
  def testBothSpanAndDate(self):
    """Expects a ValueError when a pattern mixes date specs with {SPAN}."""
    mixed_spec_split = example_gen_pb2.Input.Split(
        name='s1', pattern='{YYYY}-{MM}-{DD}/{SPAN}/split1/*')
    expected_error = (
        'Either span spec or date specs must be specified exclusively')

    with self.assertRaisesRegex(ValueError, expected_error):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, [mixed_spec_split])
コード例 #9
0
  def testVersionNoMatching(self):
    """Expects a ValueError when nothing matches the version part of a pattern.

    The only file on disk is 'span01/wrong/data', which has no
    'version{VERSION}' directory for the pattern to match.
    """
    base_path = self._input_base_path
    unversioned_file = os.path.join(base_path, 'span01', 'wrong', 'data')
    io_utils.write_string_file(unversioned_file,
                               'testing_version_no_matching')

    versioned_split = example_gen_pb2.Input.Split(
        name='s1', pattern='span{SPAN}/version{VERSION}/split1/*')

    with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
      utils.calculate_splits_fingerprint_span_and_version(
          base_path, [versioned_split])
コード例 #10
0
 def testHaveVersionNoSpan(self):
     """Expects a ValueError when {VERSION} is used without a span or date spec."""
     # Test specific behavior when Version spec is present but Span is not.
     splits = [
         example_gen_pb2.Input.Split(name='s1',
                                     pattern='version{VERSION}/split1/*')
     ]
     # assertRaisesRegex is used because the assertRaisesRegexp alias is
     # deprecated and was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(
             ValueError,
             'Version spec provided, but Span or Date spec is not present'):
         utils.calculate_splits_fingerprint_span_and_version(
             self._input_base_path, splits)
コード例 #11
0
  def testInvalidDate(self):
    """Expects a ValueError when the matched digits do not form a valid date.

    The directory '20201301' has the right shape for {YYYY}{MM}{DD} but
    encodes month 13.
    """
    base_path = self._input_base_path
    io_utils.write_string_file(
        os.path.join(base_path, '20201301', 'split1', 'data'), 'testing')

    date_split = example_gen_pb2.Input.Split(
        name='s1', pattern='{YYYY}{MM}{DD}/split1/*')

    with self.assertRaisesRegex(ValueError, 'Retrieved date is invalid'):
      utils.calculate_splits_fingerprint_span_and_version(
          base_path, [date_split])
コード例 #12
0
  def testDateBadFormat(self):
    """Expects a ValueError when the date directory is not made of digits.

    The only directory on disk is the literal 'yyyymmdd'.
    """
    base_path = self._input_base_path
    malformed_date_file = os.path.join(base_path, 'yyyymmdd', 'split1', 'data')
    io_utils.write_string_file(malformed_date_file, 'testing')

    date_split = example_gen_pb2.Input.Split(
        name='s1', pattern='{YYYY}{MM}{DD}/split1/*')
    expected_error = 'Cannot find span number using date'

    with self.assertRaisesRegex(ValueError, expected_error):
      utils.calculate_splits_fingerprint_span_and_version(
          base_path, [date_split])
コード例 #13
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Parses the JSON input config stored in `exec_properties`, resolves the
        fingerprint, span and version for its splits, and stores the results
        back into `exec_properties`.

        Args:
          exec_properties: Execution properties; must contain the entries
            keyed by utils.INPUT_CONFIG_KEY and utils.INPUT_BASE_KEY. The
            dict is modified in place.
          pipeline_info: Unused.
          component_info: Unused.

        Returns:
          The same `exec_properties` dict, with the re-serialized input
          config plus the span, version and fingerprint entries set.
        """
        del pipeline_info, component_info

        # json_format.Parse fills in and returns the message it is given.
        parsed_input = json_format.Parse(
            exec_properties[utils.INPUT_CONFIG_KEY], example_gen_pb2.Input())

        base_uri = exec_properties[utils.INPUT_BASE_KEY]
        logging.debug('Processing input %s.', base_uri)

        # Note that this call also updates the pattern of each split in
        # parsed_input, which is why the config is serialized again below.
        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_uri, parsed_input.splits)
        fingerprint, span, version = resolved

        exec_properties.update({
            utils.INPUT_CONFIG_KEY:
                json_format.MessageToJson(
                    parsed_input,
                    sort_keys=True,
                    preserving_proto_field_name=True),
            utils.SPAN_PROPERTY_NAME: span,
            utils.VERSION_PROPERTY_NAME: version,
            utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
        })

        return exec_properties
コード例 #14
0
  def testRangeConfigWithNonexistentSpan(self):
    """Expects a ValueError when the RangeConfig selects a span not on disk.

    Only 'span01' is written, while the static range asks for span 2.
    """
    base_path = self._input_base_path
    io_utils.write_string_file(
        os.path.join(base_path, 'span01', 'split1', 'data'), 'testing11')

    missing_span_range = range_config_pb2.StaticRange(
        start_span_number=2, end_span_number=2)
    range_config = range_config_pb2.RangeConfig(
        static_range=missing_span_range)
    padded_span_split = example_gen_pb2.Input.Split(
        name='s1', pattern='span{SPAN:2}/split1/*')

    with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
      utils.calculate_splits_fingerprint_span_and_version(
          base_path, [padded_span_split], range_config=range_config)
コード例 #15
0
 def resolve_span_and_version(self) -> Tuple[int, Optional[int]]:
     """Resolves the span and version for this object's splits.

     The fingerprint computed alongside them is stored on
     `self._fingerprint`. The resolution runs on deep copies of
     `self._splits`, so the stored splits themselves are not passed to the
     utility.

     Returns:
       A (span, version) tuple as returned by
       utils.calculate_splits_fingerprint_span_and_version.
     """
     # TODO(b/181275944): refactor to use base resolve_span_and_version.
     copied_splits = [copy.deepcopy(split) for split in self._splits]
     resolved = utils.calculate_splits_fingerprint_span_and_version(
         self._input_base_uri, copied_splits, self._range_config)
     self._fingerprint, span, version = resolved
     return span, version
コード例 #16
0
ファイル: utils_test.py プロジェクト: jbmunro4/tfx
    def testMultipleSpecs(self):
        """Expects a ValueError when {SPAN} or {VERSION} appears more than once."""
        splits1 = [
            example_gen_pb2.Input.Split(
                name='s1', pattern='span1{SPAN}/span2{SPAN}/split1/*')
        ]
        # assertRaisesRegex is used because the assertRaisesRegexp alias is
        # deprecated and was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError, 'Only one {SPAN} is allowed'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits1)

        splits2 = [
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='span{SPAN}/ver1{VERSION}/ver2{VERSION}/split1/*')
        ]
        with self.assertRaisesRegex(ValueError,
                                    'Only one {VERSION} is allowed'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits2)
コード例 #17
0
    def testNoSpanOrVersion(self):
        """Checks the result for a pattern with neither a span nor a version spec.

        The resolved span is expected to be 0 and the version None.
        """
        base_path = self._input_base_path
        io_utils.write_string_file(
            os.path.join(base_path, 'split1', 'data'), 'testing')

        plain_split = example_gen_pb2.Input.Split(name='s1',
                                                  pattern='split1/*')

        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, [plain_split])
        _, span, version = resolved
        self.assertEqual(span, 0)
        self.assertIsNone(version)
コード例 #18
0
  def testHaveSpanNoVersion(self):
    """Checks the result for a pattern with a span spec but no version spec.

    With data under 'span1', the span is expected to be 1 and the version
    None.
    """
    base_path = self._input_base_path
    io_utils.write_string_file(
        os.path.join(base_path, 'span1', 'split1', 'data'), 'testing')

    span_only_split = example_gen_pb2.Input.Split(
        name='s1', pattern='span{SPAN}/split1/*')

    resolved = utils.calculate_splits_fingerprint_span_and_version(
        base_path, [span_only_split])
    _, span, version = resolved
    self.assertEqual(span, 1)
    self.assertIsNone(version)
コード例 #19
0
    def testSpanVersionWidthNoSeperator(self):
        """Checks fixed-width span and version specs with nothing between them.

        The directory '1234' matched against '{SPAN:2}{VERSION:2}' is expected
        to yield span 12 and version 34.
        """
        base_path = self._input_base_path
        io_utils.write_string_file(
            os.path.join(base_path, '1234', 'split1', 'data'), 'testing')

        adjacent_specs_split = example_gen_pb2.Input.Split(
            name='s1', pattern='{SPAN:2}{VERSION:2}/split1/*')

        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, [adjacent_specs_split])
        _, span, version = resolved
        self.assertEqual(span, 12)
        self.assertEqual(version, 34)
コード例 #20
0
ファイル: utils_test.py プロジェクト: jbmunro4/tfx
    def testHaveSpanAndVersion(self):
        """Checks the result for a pattern with both a span and a version spec.

        With data under 'span1/version1', both resolved values are expected to
        equal the string '1'.
        """
        base_path = self._input_base_path
        io_utils.write_string_file(
            os.path.join(base_path, 'span1', 'version1', 'split1', 'data'),
            'testing')

        span_version_split = example_gen_pb2.Input.Split(
            name='s1', pattern='span{SPAN}/version{VERSION}/split1/*')

        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, [span_version_split])
        _, span, version = resolved
        # NOTE(review): the expected values are the strings '1', not ints —
        # confirm this matches the return type of the utils revision under
        # test.
        self.assertEqual('1', span)
        self.assertEqual('1', version)
コード例 #21
0
    def testHaveDateAndVersion(self):
        """Checks the result for a pattern with both date specs and a version spec.

        With data under '19700102/ver1', the span is expected to be 1 and the
        version 1.
        """
        base_path = self._input_base_path
        io_utils.write_string_file(
            os.path.join(base_path, '19700102', 'ver1', 'split1', 'data'),
            'testing')

        date_version_split = example_gen_pb2.Input.Split(
            name='s1', pattern='{YYYY}{MM}{DD}/ver{VERSION}/split1/*')

        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, [date_version_split])
        _, span, version = resolved
        self.assertEqual(span, 1)
        self.assertEqual(version, 1)
コード例 #22
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Parses the JSON input config (and the optional JSON range config)
        from `exec_properties`, resolves the fingerprint, span and version for
        the configured splits, and stores the results back into
        `exec_properties`.

        Args:
          exec_properties: Execution properties; must contain the entries
            keyed by standard_component_specs.INPUT_CONFIG_KEY and
            INPUT_BASE_KEY, and may contain RANGE_CONFIG_KEY. The dict is
            modified in place.
          pipeline_info: Unused.
          component_info: Unused.

        Returns:
          The same `exec_properties` dict, with the re-serialized input
          config plus the span, version and fingerprint entries set.

        Raises:
          ValueError: If the range config has a static_range whose start and
            end span numbers differ.
        """
        del pipeline_info, component_info

        parsed_input = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            parsed_input)

        base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY]
        logging.debug('Processing input %s.', base_uri)

        range_config = None
        serialized_range = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)

            if range_config.HasField('static_range'):
                # For ExampleGen, StaticRange must specify an exact span to
                # look for, since only one span is processed at a time.
                static_range = range_config.static_range
                first_span = static_range.start_span_number
                last_span = static_range.end_span_number
                if first_span != last_span:
                    raise ValueError(
                        'Start and end span numbers for RangeConfig.static_range must '
                        'be equal: (%s, %s)' % (first_span, last_span))

        # Note that this call also updates the pattern of each split in
        # parsed_input, which is why the config is serialized again below.
        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_uri, parsed_input.splits, range_config)
        fingerprint, span, version = resolved

        exec_properties.update({
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(parsed_input),
            utils.SPAN_PROPERTY_NAME: span,
            utils.VERSION_PROPERTY_NAME: version,
            utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
        })

        return exec_properties
コード例 #23
0
    def testNewSpanWithOlderVersionAlign(self):
        """Checks resolution when the newest span has a lower version number.

        With 'span1/ver2' and 'span2/ver1' on disk, span 2 and version 1 are
        expected to be selected.
        """
        base_path = self._input_base_path
        for span_dir, version_dir in (('span1', 'ver2'), ('span2', 'ver1')):
            io_utils.write_string_file(
                os.path.join(base_path, span_dir, version_dir, 'split1',
                             'data'), 'testing')

        span_version_split = example_gen_pb2.Input.Split(
            name='s1', pattern='span{SPAN}/ver{VERSION}/split1/*')

        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, [span_version_split])
        _, span, version = resolved
        self.assertEqual(span, 2)
        self.assertEqual(version, 1)
コード例 #24
0
    def testCalculateSplitsFingerprint(self):
        """Checks the fingerprint string produced for two plain splits.

        Each data file gets a fixed access/modification time via os.utime
        before the fingerprint is computed. The span is expected to be 0 and
        the version None since the patterns carry no specs.
        """
        base_path = self._input_base_path
        # (split directory, file contents, modification time)
        fixtures = (
            ('split1', 'testing', 1),
            ('split2', 'testing2', 3),
        )
        for split_dir, contents, mtime in fixtures:
            data_path = os.path.join(base_path, split_dir, 'data')
            io_utils.write_string_file(data_path, contents)
            os.utime(data_path, (0, mtime))

        splits = [
            example_gen_pb2.Input.Split(name=split_name, pattern=split_pattern)
            for split_name, split_pattern in (('s1', 'split1/*'),
                                              ('s2', 'split2/*'))
        ]
        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, splits)
        fingerprint, span, version = resolved

        expected_fingerprint = '\n'.join([
            'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1',
            'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3',
        ])
        self.assertEqual(fingerprint, expected_fingerprint)
        self.assertEqual(span, 0)
        self.assertIsNone(version)
コード例 #25
0
    def testRangeConfigWithDateSpec(self):
        """Checks a static RangeConfig combined with date specs in the pattern.

        The static range is built from the span number of 1970-01-02; the
        pattern is expected to be rewritten to '19700102/split1/*' with span 1
        and no version.
        """
        base_path = self._input_base_path
        io_utils.write_string_file(
            os.path.join(base_path, '19700102', 'split1', 'data'),
            'testing11')

        static_range = range_config_pb2.StaticRange(
            start_span_number=utils.date_to_span_number(1970, 1, 2),
            end_span_number=utils.date_to_span_number(1970, 1, 2))
        range_config = range_config_pb2.RangeConfig(static_range=static_range)

        date_split = example_gen_pb2.Input.Split(
            name='s1', pattern='{YYYY}{MM}{DD}/split1/*')
        splits = [date_split]
        resolved = utils.calculate_splits_fingerprint_span_and_version(
            base_path, splits, range_config=range_config)
        _, span, version = resolved

        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(splits[0].pattern, '19700102/split1/*')
コード例 #26
0
  def testSpanAlignWithRangeConfig(self):
    """Checks that a static RangeConfig selects the requested span.

    Both 'span01' and 'span02' exist on disk; with a static range of span 1
    the pattern is expected to be rewritten to 'span01/split1/*'.
    """
    base_path = self._input_base_path
    for span_dir, contents in (('span01', 'testing11'),
                               ('span02', 'testing21')):
      io_utils.write_string_file(
          os.path.join(base_path, span_dir, 'split1', 'data'), contents)

    # Test static range in RangeConfig.
    static_range = range_config_pb2.StaticRange(
        start_span_number=1, end_span_number=1)
    range_config = range_config_pb2.RangeConfig(static_range=static_range)
    splits = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN:2}/split1/*')
    ]

    resolved = utils.calculate_splits_fingerprint_span_and_version(
        base_path, splits, range_config)
    _, span, version = resolved
    self.assertEqual(span, 1)
    self.assertIsNone(version)
    self.assertEqual(splits[0].pattern, 'span01/split1/*')
コード例 #27
0
def _run_driver(exec_properties: Dict[str, Any],
                outputs_dict: Dict[str, List[artifact.Artifact]],
                output_metadata_uri: str,
                name_from_id: Optional[Dict[int, str]] = None) -> None:
    """Runs the driver and writes its result as an ExecutorOutput proto.

    The driver calculates the fingerprint, span and version of the input
    data, so that the executor invocation can be skipped when the ExampleGen
    component has already run on the same data with the same configuration.
    The values are stored in `exec_properties`, applied to the examples
    output artifact, and the fingerprint, span and input config are written
    as parameters of an ExecutorOutput message at `output_metadata_uri`. The
    CAIP pipelines system reads this file and updates MLMD with the new
    execution properties.

    Args:
      exec_properties: Execution properties, modified in place. They must
        contain the entries keyed by utils.INPUT_BASE_KEY (a path from which
        files will be read and their span/fingerprint calculated) and
        utils.INPUT_CONFIG_KEY (a json-serialized
        tfx.proto.example_gen_pb2.Input proto message). See
        https://www.tensorflow.org/tfx/guide/examplegen for more details.
      outputs_dict: The mapping of the output artifacts; must contain
        utils.EXAMPLES_KEY.
      output_metadata_uri: A path at which an ExecutorOutput message will be
        written with updated execution properties and output artifacts. The
        CAIP Pipelines service will update the task's properties and
        artifacts prior to running the executor.
      name_from_id: Optional. Mapping from the converted int-typed id to
        str-typed runtime artifact name, which should be unique.

    Raises:
      ValueError: If `outputs_dict` has no utils.EXAMPLES_KEY entry.
    """
    name_from_id = {} if name_from_id is None else name_from_id

    logging.set_verbosity(logging.INFO)
    logging.info('exec_properties = %s\noutput_metadata_uri = %s',
                 exec_properties, output_metadata_uri)

    base_uri = exec_properties[utils.INPUT_BASE_KEY]

    parsed_input = example_gen_pb2.Input()
    proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                              parsed_input)

    # TODO(b/161734559): Support range config.
    resolved = utils.calculate_splits_fingerprint_span_and_version(
        base_uri, parsed_input.splits)
    fingerprint, span, version = resolved
    logging.info('Calculated span: %s', span)
    logging.info('Calculated fingerprint: %s', fingerprint)

    exec_properties[utils.SPAN_PROPERTY_NAME] = span
    exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
    exec_properties[utils.VERSION_PROPERTY_NAME] = version

    if utils.EXAMPLES_KEY not in outputs_dict:
        raise ValueError(
            'Example artifact was missing in the ExampleGen outputs.')
    examples = artifact_utils.get_single_instance(
        outputs_dict[utils.EXAMPLES_KEY])

    driver.update_output_artifact(exec_properties=exec_properties,
                                  output_artifact=examples.mlmd_artifact)

    # Log the output metadata file.
    # NOTE(review): the version is stored in exec_properties above but is not
    # written as an ExecutorOutput parameter here — confirm this is intended.
    executor_output = pipeline_pb2.ExecutorOutput()
    parameters = executor_output.parameters
    parameters[utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint
    parameters[utils.SPAN_PROPERTY_NAME].string_value = str(span)
    parameters[utils.INPUT_CONFIG_KEY].string_value = (
        json_format.MessageToJson(parsed_input))
    runtime_artifact = kubeflow_v2_entrypoint_utils.to_runtime_artifact(
        examples, name_from_id)
    executor_output.artifacts[utils.EXAMPLES_KEY].artifacts.add().CopyFrom(
        runtime_artifact)

    fileio.makedirs(os.path.dirname(output_metadata_uri))
    # NOTE(review): MessageToJson produces text while the file is opened in
    # 'wb' mode — presumably fileio accepts str here; confirm.
    with fileio.open(output_metadata_uri, 'wb') as metadata_file:
        metadata_file.write(
            json_format.MessageToJson(executor_output, sort_keys=True))