Пример #1
0
 def test_set_payload_format(self):
     """Setting FORMAT_PROTO stores the name 'FORMAT_PROTO' on the artifact."""
     artifact = standard_artifacts.Examples()
     proto_format = example_gen_pb2.PayloadFormat.FORMAT_PROTO
     examples_utils.set_payload_format(artifact, proto_format)

     # The format is read back through the string custom property.
     stored_format = artifact.get_string_custom_property(
         utils.PAYLOAD_FORMAT_PROPERTY_NAME)
     self.assertEqual(stored_format, 'FORMAT_PROTO')
Пример #2
0
 def test_get_tfxio_factory_from_artifact(self,
                                          payload_format,
                                          expected_tfxio_type,
                                          raw_record_column_name=None,
                                          provide_data_view_uri=False,
                                          read_as_raw_records=False):
     """Builds a TFXIO from an Examples artifact and checks its properties.

     Args:
       payload_format: Payload format to set on the artifact; skipped when None.
       expected_tfxio_type: Type the created TFXIO must be an instance of.
       raw_record_column_name: Forwarded to the factory and expected to be
         reported back by the created TFXIO.
       provide_data_view_uri: When true, a decoder is saved into a fresh temp
         directory and that directory is attached to the artifact as its
         data view URI.
       read_as_raw_records: Forwarded to the factory.
     """
     artifact = standard_artifacts.Examples()
     if payload_format is not None:
         examples_utils.set_payload_format(artifact, payload_format)

     if provide_data_view_uri:
         decoder_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
         tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                              decoder_dir)
         artifact.set_string_custom_property(
             constants.DATA_VIEW_URI_PROPERTY_KEY, decoder_dir)

     make_tfxio = tfxio_utils.get_tfxio_factory_from_artifact(
         artifact, _TELEMETRY_DESCRIPTORS, _SCHEMA, read_as_raw_records,
         raw_record_column_name)
     produced = make_tfxio(_FAKE_FILE_PATTERN)

     self.assertIsInstance(produced, expected_tfxio_type)
     # Only RecordBasedTFXIO instances are created at the moment; the
     # attribute checks below depend on that.
     self.assertIsInstance(produced, record_based_tfxio.RecordBasedTFXIO)
     self.assertEqual(produced.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
     self.assertEqual(produced.raw_record_column_name,
                      raw_record_column_name)
     # A schema was supplied, so ArrowSchema() is expected not to raise.
     produced.ArrowSchema()
Пример #3
0
 def test_raise_if_data_view_uri_not_available(self):
     """A FORMAT_PROTO artifact with no data view URI set must be rejected."""
     artifact = standard_artifacts.Examples()
     examples_utils.set_payload_format(
         artifact, example_gen_pb2.PayloadFormat.FORMAT_PROTO)

     # Both creating the factory and invoking it sit inside the context
     # manager, so the error may come from either step.
     with self.assertRaisesRegex(AssertionError, 'requires a DataView'):
         make_tfxio = tfxio_utils.get_tfxio_factory_from_artifact(
             artifact, _TELEMETRY_DESCRIPTORS)
         make_tfxio(_FAKE_FILE_PATTERN)
Пример #4
0
 def test_resolve_payload_format_and_data_view_uri(
         self,
         payload_formats,
         data_view_uris=None,
         data_view_ids=None,
         expected_payload_format=None,
         expected_data_view_uri=None,
         expected_error_type=None,
         expected_error_msg_regex=None):
     """Resolves format and data view URI over a list of Examples artifacts.

     One artifact is built per entry of `payload_formats`; the matching
     entries of `data_view_uris` / `data_view_ids` are attached as custom
     properties when they are not None.

     Args:
       payload_formats: Payload format for each artifact.
       data_view_uris: Optional per-artifact data view URIs; defaults to all
         None.
       data_view_ids: Optional per-artifact data view ids; defaults to all
         None.
       expected_payload_format: Expected resolved payload format.
       expected_data_view_uri: Expected resolved data view URI.
       expected_error_type: When given, resolution must raise this type.
       expected_error_msg_regex: Regex the raised error message must match.
     """
     uris = (data_view_uris if data_view_uris is not None
             else [None] * len(payload_formats))
     view_ids = (data_view_ids if data_view_ids is not None
                 else [None] * len(payload_formats))

     def build_artifact(fmt, uri, view_id):
         # Creates one Examples artifact carrying the given properties.
         built = standard_artifacts.Examples()
         examples_utils.set_payload_format(built, fmt)
         if uri is not None:
             built.set_string_custom_property(
                 constants.DATA_VIEW_URI_PROPERTY_KEY, uri)
         if view_id is not None:
             built.set_int_custom_property(
                 constants.DATA_VIEW_ID_PROPERTY_KEY, view_id)
         return built

     artifacts = [
         build_artifact(fmt, uri, view_id)
         for fmt, uri, view_id in zip(payload_formats, uris, view_ids)
     ]

     if expected_error_type is not None:
         with self.assertRaisesRegex(expected_error_type,
                                     expected_error_msg_regex):
             tfxio_utils.resolve_payload_format_and_data_view_uri(artifacts)
         return

     resolved_format, resolved_uri = (
         tfxio_utils.resolve_payload_format_and_data_view_uri(artifacts))
     self.assertEqual(resolved_format, expected_payload_format)
     self.assertEqual(resolved_uri, expected_data_view_uri)
Пример #5
0
  def test_get_tfxio_factory_from_artifact_data_view_legacy(self):
    """Checks FORMAT_PROTO with a data view whose create time is an int.

    This tests FORMAT_PROTO with data view where the DATA_VIEW_CREATE_TIME_KEY
    is an int value. This is a legacy property type and should be string type
    in the future.
    """
    # Compare the numeric major version: a plain string comparison such as
    # `tf.__version__ < '2'` is lexicographic and would misclassify a
    # version like '10.0.0' as older than '2'.
    tf_major_version = int(tf.__version__.split('.')[0])
    if tf_major_version < 2:
      self.skipTest('DataView is not supported under TF 1.x.')

    examples = standard_artifacts.Examples()
    examples_utils.set_payload_format(
        examples, example_gen_pb2.PayloadFormat.FORMAT_PROTO)
    # Save a decoder into a fresh temp directory and use it as the data view.
    data_view_uri = tempfile.mkdtemp(dir=self.get_temp_dir())
    tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                         data_view_uri)
    examples.set_string_custom_property(constants.DATA_VIEW_URI_PROPERTY_KEY,
                                        data_view_uri)
    # NOTE(review): the value is passed as the string '1' to the int setter;
    # presumably the setter coerces it to an int - confirm.
    examples.set_int_custom_property(constants.DATA_VIEW_CREATE_TIME_KEY, '1')
    tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
        [examples],
        _TELEMETRY_DESCRIPTORS,
        _SCHEMA,
        read_as_raw_records=False,
        raw_record_column_name=None)
    tfxio = tfxio_factory(_FAKE_FILE_PATTERN)
    self.assertIsInstance(tfxio, record_to_tensor_tfxio.TFRecordToTensorTFXIO)
    # We currently only create RecordBasedTFXIO and the check below relies on
    # that.
    self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
    self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
    # Since we provide a schema, ArrowSchema() should not raise.
    _ = tfxio.ArrowSchema()
Пример #6
0
    def test_get_tf_dataset_factory_from_artifact(self):
        """The dataset factory is callable and annotated to return a Dataset."""
        artifact = standard_artifacts.Examples()
        tf_example_format = example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE
        examples_utils.set_payload_format(artifact, tf_example_format)

        factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
            [artifact], _TELEMETRY_DESCRIPTORS)

        self.assertIsInstance(factory, Callable)
        # Only the declared return annotation is inspected; the factory
        # itself is never invoked here.
        declared_return = inspect.signature(factory).return_annotation
        self.assertEqual(tf.data.Dataset, declared_return)
Пример #7
0
  def test_get_record_batch_factory_from_artifact(self):
    """The record batch factory is callable with the expected annotation."""
    artifact = standard_artifacts.Examples()
    tf_example_format = example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE
    examples_utils.set_payload_format(artifact, tf_example_format)

    factory = tfxio_utils.get_record_batch_factory_from_artifact(
        [artifact], _TELEMETRY_DESCRIPTORS)

    self.assertIsInstance(factory, Callable)
    # Only the declared return annotation is inspected; the factory itself is
    # never invoked here.
    declared_return = inspect.signature(factory).return_annotation
    self.assertEqual(Iterator[pa.RecordBatch], declared_return)
Пример #8
0
  def Do(
      self,
      input_dict: Dict[Text, List[types.Artifact]],
      output_dict: Dict[Text, List[types.Artifact]],
      exec_properties: Dict[Text, Any],
  ) -> None:
    """Generates example splits with Beam and writes them to the output URIs.

    The written payload is intended to be serialized tf.train.Examples or
    tf.train.SequenceExamples protocol buffers in gzipped TFRecord format.
    Subclasses may override this to write any serialized record payload into
    gzipped TFRecord, as long as downstream components can consume it. The
    payload format is recorded in the `payload_format` custom property of the
    output Examples artifact.

    Args:
      input_dict: Input dict from input key to a list of Artifacts. Depends on
        detailed example gen implementation.
      output_dict: Output dict from output key to a list of Artifacts.
        - examples: splits of serialized records.
      exec_properties: A dict of execution properties. Depends on detailed
        example gen implementation.
        - input_base: an external directory containing the data files.
        - input_config: JSON string of example_gen_pb2.Input instance,
          providing input configuration.
        - output_config: JSON string of example_gen_pb2.Output instance,
          providing output configuration.
        - output_data_format: Payload format of generated data in output
          artifact, one of example_gen_pb2.PayloadFormat enum.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    logging.info('Generating examples.')
    with self._make_beam_pipeline() as pipeline:
      splits_by_name = self.GenerateExamplesByBeam(pipeline, exec_properties)

      # Attach one write step per split; each step gets a unique label.
      # pylint: disable=expression-not-assigned, no-value-for-parameter
      for split_name, split_records in splits_by_name.items():
        step_label = 'WriteSplit[{}]'.format(split_name)
        split_uri = artifact_utils.get_split_uri(
            output_dict[utils.EXAMPLES_KEY], split_name)
        split_records | step_label >> _WriteSplit(split_uri)
      # pylint: enable=expression-not-assigned, no-value-for-parameter

    # Tag the output artifacts only when a payload format was supplied.
    requested_format = exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY)
    if requested_format:
      for examples_artifact in output_dict[utils.EXAMPLES_KEY]:
        examples_utils.set_payload_format(examples_artifact, requested_format)
    logging.info('Examples generated.')
Пример #9
0
 def test_set_payload_format_invalid_artifact_type(self):
     """Setting a payload format on a Schema artifact raises AssertionError."""
     schema_artifact = standard_artifacts.Schema()
     proto_format = example_gen_pb2.PayloadFormat.FORMAT_PROTO
     with self.assertRaises(AssertionError):
         examples_utils.set_payload_format(schema_artifact, proto_format)