Ejemplo n.º 1
0
 def test_get_tfxio_factory_from_artifact(self,
                                          payload_format,
                                          expected_tfxio_type,
                                          raw_record_column_name=None,
                                          provide_data_view_uri=False,
                                          read_as_raw_records=False):
     """Checks the TFXIO built by the factory derived from an Examples artifact.

     Args:
       payload_format: payload format to record on the artifact; when None the
         artifact's payload format is left unset.
       expected_tfxio_type: type the created TFXIO is expected to be an
         instance of.
       raw_record_column_name: forwarded to the factory and compared against
         the created TFXIO's `raw_record_column_name`.
       provide_data_view_uri: when true, a decoder is saved to a fresh temp
         directory and that directory is attached to the artifact as its data
         view URI.
       read_as_raw_records: forwarded to the factory.
     """
     artifact = standard_artifacts.Examples()
     if payload_format is not None:
         examples_utils.set_payload_format(artifact, payload_format)

     # The data view URI custom property is only set when a decoder was saved.
     if provide_data_view_uri:
         decoder_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
         tf_graph_record_decoder.save_decoder(
             _SimpleTfGraphRecordDecoder(), decoder_dir)
         artifact.set_string_custom_property(
             constants.DATA_VIEW_URI_PROPERTY_KEY, decoder_dir)

     make_tfxio = tfxio_utils.get_tfxio_factory_from_artifact(
         artifact,
         _TELEMETRY_DESCRIPTORS,
         _SCHEMA,
         read_as_raw_records,
         raw_record_column_name)
     produced = make_tfxio(_FAKE_FILE_PATTERN)

     self.assertIsInstance(produced, expected_tfxio_type)
     # Only RecordBasedTFXIO instances are created at the moment; the
     # attribute checks that follow depend on that.
     self.assertIsInstance(produced, record_based_tfxio.RecordBasedTFXIO)
     self.assertEqual(produced.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
     self.assertEqual(produced.raw_record_column_name,
                      raw_record_column_name)
     # A schema was supplied, so ArrowSchema() is expected not to raise.
     produced.ArrowSchema()
    def test_save_load_decode(self):
        """Saves a decoder, loads it back, and checks what it decodes."""
        records = [b"abc", b"def"]
        expected_specs = {
            "sparse_tensor":
                tf.SparseTensorSpec(shape=[None, None], dtype=tf.string),
            "ragged_tensor":
                tf.RaggedTensorSpec(
                    shape=[None, None], dtype=tf.string, ragged_rank=1),
        }

        original = _DecoderForTesting()
        self.assertEqual(original.output_type_specs(), expected_specs)

        # Round-trip the decoder through the test's temporary directory.
        tf_graph_record_decoder.save_decoder(original, self._tmp_dir)
        restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
        self.assertEqual(original.output_type_specs(),
                         restored.output_type_specs())

        decoded = restored.decode_record(records)
        self.assertLen(decoded, len(restored.output_type_specs()))
        self.assertIn("sparse_tensor", decoded)

        sparse = decoded["sparse_tensor"]
        self.assertAllEqual(sparse.values, [b"abc", b"def"])
        self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
        self.assertAllEqual(sparse.dense_shape, [2, 1])

        ragged = decoded["ragged_tensor"]
        self.assertAllEqual(ragged,
                            tf.ragged.constant([[b"abc"], [b"def"]]))
Ejemplo n.º 3
0
  def test_get_tfxio_factory_from_artifact_data_view_legacy(self):
    """Checks the factory when the data view create time is an int property.

    This tests FORMAT_PROTO with a data view where DATA_VIEW_CREATE_TIME_KEY
    is stored as an int custom property. That is a legacy property type and
    should become a string type in the future.
    """
    # Compare the major version numerically: comparing version strings
    # lexicographically would order e.g. '10.0' before '2'.
    tf_major_version = int(tf.__version__.split('.', 1)[0])
    if tf_major_version < 2:
      self.skipTest('DataView is not supported under TF 1.x.')

    examples = standard_artifacts.Examples()
    examples_utils.set_payload_format(
        examples, example_gen_pb2.PayloadFormat.FORMAT_PROTO)
    data_view_uri = tempfile.mkdtemp(dir=self.get_temp_dir())
    tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                         data_view_uri)
    examples.set_string_custom_property(constants.DATA_VIEW_URI_PROPERTY_KEY,
                                        data_view_uri)
    # Pass a real int so the legacy int-typed property is what gets exercised.
    examples.set_int_custom_property(constants.DATA_VIEW_CREATE_TIME_KEY, 1)
    tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
        [examples],
        _TELEMETRY_DESCRIPTORS,
        _SCHEMA,
        read_as_raw_records=False,
        raw_record_column_name=None)
    tfxio = tfxio_factory(_FAKE_FILE_PATTERN)
    self.assertIsInstance(tfxio, record_to_tensor_tfxio.TFRecordToTensorTFXIO)
    # We currently only create RecordBasedTFXIO and the check below relies on
    # that.
    self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
    self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
    # Since we provide a schema, ArrowSchema() should not raise.
    _ = tfxio.ArrowSchema()
Ejemplo n.º 4
0
 def test_make_tfxio(self,
                     payload_format,
                     expected_tfxio_type,
                     raw_record_column_name=None,
                     provide_data_view_uri=False,
                     read_as_raw_records=False):
     """Checks the TFXIO returned directly by `tfxio_utils.make_tfxio`.

     Args:
       payload_format: payload format passed to `make_tfxio`; None is replaced
         by 'FORMAT_TF_EXAMPLE'.
       expected_tfxio_type: type the created TFXIO is expected to be an
         instance of.
       raw_record_column_name: forwarded to `make_tfxio` and compared against
         the created TFXIO's `raw_record_column_name`.
       provide_data_view_uri: when true, a decoder is saved to a fresh temp
         directory which is passed as the data view URI; otherwise None is
         passed.
       read_as_raw_records: forwarded to `make_tfxio`.
     """
     effective_format = (
         'FORMAT_TF_EXAMPLE' if payload_format is None else payload_format)

     decoder_dir = None
     if provide_data_view_uri:
         decoder_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
         tf_graph_record_decoder.save_decoder(
             _SimpleTfGraphRecordDecoder(), decoder_dir)

     produced = tfxio_utils.make_tfxio(
         _FAKE_FILE_PATTERN,
         _TELEMETRY_DESCRIPTORS,
         effective_format,
         decoder_dir,
         _SCHEMA,
         read_as_raw_records,
         raw_record_column_name)

     self.assertIsInstance(produced, expected_tfxio_type)
     # Only RecordBasedTFXIO instances are created at the moment; the
     # attribute checks that follow depend on that.
     self.assertIsInstance(produced, record_based_tfxio.RecordBasedTFXIO)
     self.assertEqual(produced.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
     self.assertEqual(produced.raw_record_column_name,
                      raw_record_column_name)
     # A schema was supplied, so ArrowSchema() is expected not to raise.
     produced.ArrowSchema()
Ejemplo n.º 5
0
 def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> None:
     """Creates a decoder and saves it to the data view output artifact's URI.

     Args:
       input_dict: input artifacts; only passed to the startup logging.
       output_dict: output artifacts; the value stored under `_DATA_VIEW_KEY`
         supplies the URI the decoder is saved to.
       exec_properties: execution properties from which the decoder-creating
         function is resolved via `_CREATE_DECODER_FUNC_KEY`.
     """
     self._log_startup(input_dict, output_dict, exec_properties)

     decoder_factory = udf_utils.get_fn(exec_properties,
                                        _CREATE_DECODER_FUNC_KEY)
     # Build the decoder before resolving the destination, as a single
     # nested call expression would.
     decoder = decoder_factory()
     data_view = value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY)
     tf_graph_record_decoder.save_decoder(decoder, data_view.uri)
  def setUp(self):
    """Saves a test decoder and writes test input under a unique directory."""
    super(RecordToTensorTfxioTest, self).setUp()

    # A random hex component gives each setUp call its own directory.
    base_dir = os.path.join(
        FLAGS.test_tmpdir, "recordtotensortfxiotest", uuid.uuid4().hex)

    self._decoder_path = base_dir
    tf_graph_record_decoder.save_decoder(
        _DecoderForTesting(), self._decoder_path)

    # The input lives in an "input" entry inside the decoder directory.
    self._input_path = os.path.join(base_dir, "input")
    _write_input(self._input_path)
Ejemplo n.º 7
0
    def test_save_load_decode(self):
        """Round-trips a decoder that exposes a record index tensor name."""
        index_name = "record_index"
        expected_specs = {
            "sparse_tensor":
                tf.SparseTensorSpec(shape=[None, None], dtype=tf.string),
            "ragged_tensor":
                tf.RaggedTensorSpec(
                    shape=[None, None], dtype=tf.string, ragged_rank=1),
            "record_index":
                tf.RaggedTensorSpec(
                    shape=[None, None], dtype=tf.int64, ragged_rank=1),
            "dense_tensor":
                tf.TensorSpec(shape=[None], dtype=tf.string),
        }

        original = _DecoderForTestWithRecordIndexTensorName()
        self.assertEqual(original.output_type_specs(), expected_specs)
        self.assertEqual(original.record_index_tensor_name, index_name)

        # Save through the module-level function and load the result back.
        tf_graph_record_decoder.save_decoder(original, self._tmp_dir)
        restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
        self.assertEqual(restored.record_index_tensor_name, index_name)
        self._assert_type_specs_equal(original.output_type_specs(),
                                      restored.output_type_specs())

        records = [b"abc", b"def"]
        decoded = restored.decode_record(records)
        self.assertLen(decoded, len(restored.output_type_specs()))
        self.assertIn("sparse_tensor", decoded)

        sparse = decoded["sparse_tensor"]
        self.assertAllEqual(sparse.values, records)
        self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
        self.assertAllEqual(sparse.dense_shape, [2, 1])

        ragged = decoded["ragged_tensor"]
        self.assertAllEqual(ragged,
                            tf.ragged.constant([[b"abc"], [b"def"]]))

        record_index = decoded["record_index"]
        self.assertAllEqual(record_index, tf.ragged.constant([[0], [1]]))

        dense = decoded["dense_tensor"]
        self.assertAllEqual(dense, records)

        # The record index tensor name must also be readable in graph mode.
        with tf.compat.v1.Graph().as_default():
            self.assertFalse(tf.executing_eagerly())
            graph_restored = tf_graph_record_decoder.load_decoder(
                self._tmp_dir)
            self.assertEqual(graph_restored.record_index_tensor_name,
                             index_name)

        # Saving through the decoder's own `save` method must work as well.
        second_path = os.path.join(self._tmp_dir, "decoder_2")
        original.save(second_path)
        resaved = tf_graph_record_decoder.load_decoder(second_path)
        self.assertEqual(resaved.record_index_tensor_name, index_name)
Ejemplo n.º 8
0
    def test_no_record_index_tensor_name(self):
        """A decoder without a record index tensor name reports None."""
        original = _DecoderForTesting()
        self.assertIsNone(original.record_index_tensor_name)

        # The name must still be None after a save/load round trip.
        tf_graph_record_decoder.save_decoder(original, self._tmp_dir)
        restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
        self.assertIsNone(restored.record_index_tensor_name)

        # The same holds when the decoder is loaded in graph mode.
        with tf.compat.v1.Graph().as_default():
            self.assertFalse(tf.executing_eagerly())
            graph_restored = tf_graph_record_decoder.load_decoder(
                self._tmp_dir)
            self.assertIsNone(graph_restored.record_index_tensor_name)
Ejemplo n.º 9
0
 def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> None:
     """Creates a decoder and saves it to the data view output artifact's URI.

     Args:
       input_dict: unused.
       output_dict: output artifacts; the value stored under `_DATA_VIEW_KEY`
         supplies the URI the decoder is saved to.
       exec_properties: execution properties. When `_MODULE_FILE_KEY` is
         present, the decoder-creating function is imported from that source
         using the name under `_CREATE_DECODER_FUNC_KEY`; otherwise it is
         resolved through `udf_utils.get_fn`.
     """
     del input_dict  # Not used by this executor.

     if _MODULE_FILE_KEY not in exec_properties:
         decoder_factory = udf_utils.get_fn(exec_properties,
                                            _CREATE_DECODER_FUNC_KEY)
     else:
         module_file = exec_properties.get(_MODULE_FILE_KEY)
         func_name = exec_properties.get(_CREATE_DECODER_FUNC_KEY)
         decoder_factory = import_utils.import_func_from_source(
             module_file, func_name)

     # Build the decoder before resolving the destination, as a single
     # nested call expression would.
     decoder = decoder_factory()
     data_view = value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY)
     tf_graph_record_decoder.save_decoder(decoder, data_view.uri)
Ejemplo n.º 10
0
def _write_decoder(decoder=None):
  """Saves a decoder into a fresh temp directory and returns that directory.

  Args:
    decoder: the decoder to save. When None, a new `_DecoderForTesting` is
      created for this call, so no decoder instance is built at import time
      or shared between calls.

  Returns:
    Path of the newly created directory (under `FLAGS.test_tmpdir`) that the
    decoder was saved to.
  """
  if decoder is None:
    decoder = _DecoderForTesting()
  result = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
  tf_graph_record_decoder.save_decoder(decoder, result)
  return result
Ejemplo n.º 11
0
 def test_do_not_save_if_record_index_tensor_name_invalid(self):
     """Saving must raise an AssertionError mentioning the tensor name."""
     bad_decoder = _DecoderForTestWithInvalidRecordIndexTensorName()
     expect_failure = self.assertRaisesRegex(
         AssertionError, "record_index_tensor_name")
     with expect_failure:
         tf_graph_record_decoder.save_decoder(bad_decoder, self._tmp_dir)