def test_get_tfxio_factory_from_artifact(self,
                                         payload_format,
                                         expected_tfxio_type,
                                         raw_record_column_name=None,
                                         provide_data_view_uri=False,
                                         read_as_raw_records=False):
    """Verifies the TFXIO produced by the factory built from an Examples artifact.

    The arguments are presumably supplied by a parameterized-test decorator
    that is outside this view.

    Args:
      payload_format: payload format to stamp on the artifact; when None the
        artifact is left without an explicit payload format.
      expected_tfxio_type: type the created TFXIO must be an instance of.
      raw_record_column_name: forwarded to the factory and compared against
        the resulting TFXIO's `raw_record_column_name`.
      provide_data_view_uri: when true, a decoder is saved to a fresh temp
        directory and that directory is attached to the artifact as its data
        view URI.
      read_as_raw_records: forwarded to the factory.
    """
    artifact = standard_artifacts.Examples()
    if payload_format is not None:
        examples_utils.set_payload_format(artifact, payload_format)

    if provide_data_view_uri:
        # Persist a decoder and point the artifact's data view at it.
        decoder_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                             decoder_dir)
        artifact.set_string_custom_property(
            constants.DATA_VIEW_URI_PROPERTY_KEY, decoder_dir)

    make_tfxio = tfxio_utils.get_tfxio_factory_from_artifact(
        artifact,
        _TELEMETRY_DESCRIPTORS,
        _SCHEMA,
        read_as_raw_records,
        raw_record_column_name)
    created = make_tfxio(_FAKE_FILE_PATTERN)

    self.assertIsInstance(created, expected_tfxio_type)
    # Only RecordBasedTFXIO instances are created at the moment; the attribute
    # checks that follow depend on that.
    self.assertIsInstance(created, record_based_tfxio.RecordBasedTFXIO)
    self.assertEqual(created.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
    self.assertEqual(created.raw_record_column_name, raw_record_column_name)
    # A schema was supplied, so asking for the Arrow schema must not raise.
    created.ArrowSchema()
def test_save_load_decode(self):
    """Round-trips a decoder through save/load and checks its decoded output."""
    original = _DecoderForTesting()
    actual_specs = original.output_type_specs()
    expected_specs = {
        "sparse_tensor":
            tf.SparseTensorSpec(shape=[None, None], dtype=tf.string),
        "ragged_tensor":
            tf.RaggedTensorSpec(
                shape=[None, None], dtype=tf.string, ragged_rank=1),
    }
    self.assertEqual(actual_specs, expected_specs)

    tf_graph_record_decoder.save_decoder(original, self._tmp_dir)
    restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)

    # The reloaded decoder must advertise the same output specs.
    self.assertEqual(original.output_type_specs(),
                     restored.output_type_specs())

    decoded = restored.decode_record([b"abc", b"def"])
    self.assertLen(decoded, len(restored.output_type_specs()))
    self.assertIn("sparse_tensor", decoded)

    sparse = decoded["sparse_tensor"]
    self.assertAllEqual(sparse.values, [b"abc", b"def"])
    self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(sparse.dense_shape, [2, 1])

    ragged = decoded["ragged_tensor"]
    self.assertAllEqual(ragged, tf.ragged.constant([[b"abc"], [b"def"]]))
def test_get_tfxio_factory_from_artifact_data_view_legacy(self):
    """Checks the factory with a data view whose create time is an int property.

    This tests FORMAT_PROTO with a data view where DATA_VIEW_CREATE_TIME_KEY
    is stored as an int custom property. That is a legacy property type and
    should be a string type in the future.
    """
    # Compare the major version numerically. A plain string comparison is
    # lexicographic, so a version such as '10.0.0' would sort before '2' and
    # wrongly trigger the skip.
    tf_major_version = int(tf.__version__.split('.')[0])
    if tf_major_version < 2:
        self.skipTest('DataView is not supported under TF 1.x.')

    examples = standard_artifacts.Examples()
    examples_utils.set_payload_format(
        examples, example_gen_pb2.PayloadFormat.FORMAT_PROTO)

    # Save a decoder and register its location as the artifact's data view.
    data_view_uri = tempfile.mkdtemp(dir=self.get_temp_dir())
    tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                         data_view_uri)
    examples.set_string_custom_property(constants.DATA_VIEW_URI_PROPERTY_KEY,
                                        data_view_uri)
    # NOTE(review): the value is passed as the string '1' to the int-property
    # setter; presumably the setter coerces it -- confirm.
    examples.set_int_custom_property(constants.DATA_VIEW_CREATE_TIME_KEY, '1')

    tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
        [examples],
        _TELEMETRY_DESCRIPTORS,
        _SCHEMA,
        read_as_raw_records=False,
        raw_record_column_name=None)
    tfxio = tfxio_factory(_FAKE_FILE_PATTERN)

    self.assertIsInstance(tfxio, record_to_tensor_tfxio.TFRecordToTensorTFXIO)
    # We currently only create RecordBasedTFXIO and the check below relies on
    # that.
    self.assertIsInstance(tfxio, record_based_tfxio.RecordBasedTFXIO)
    self.assertEqual(tfxio.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
    # Since we provide a schema, ArrowSchema() should not raise.
    _ = tfxio.ArrowSchema()
def test_make_tfxio(self,
                    payload_format,
                    expected_tfxio_type,
                    raw_record_column_name=None,
                    provide_data_view_uri=False,
                    read_as_raw_records=False):
    """Verifies the TFXIO returned directly by `tfxio_utils.make_tfxio`.

    The arguments are presumably supplied by a parameterized-test decorator
    that is outside this view.

    Args:
      payload_format: payload format handed to `make_tfxio`; None is replaced
        by 'FORMAT_TF_EXAMPLE'.
      expected_tfxio_type: type the created TFXIO must be an instance of.
      raw_record_column_name: forwarded to `make_tfxio` and compared against
        the resulting TFXIO's `raw_record_column_name`.
      provide_data_view_uri: when true, a decoder is saved to a fresh temp
        directory which is then passed as the data view URI; otherwise None
        is passed.
      read_as_raw_records: forwarded to `make_tfxio`.
    """
    effective_format = (
        'FORMAT_TF_EXAMPLE' if payload_format is None else payload_format)

    decoder_dir = None
    if provide_data_view_uri:
        decoder_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        tf_graph_record_decoder.save_decoder(_SimpleTfGraphRecordDecoder(),
                                             decoder_dir)

    created = tfxio_utils.make_tfxio(
        _FAKE_FILE_PATTERN,
        _TELEMETRY_DESCRIPTORS,
        effective_format,
        decoder_dir,
        _SCHEMA,
        read_as_raw_records,
        raw_record_column_name)

    self.assertIsInstance(created, expected_tfxio_type)
    # Only RecordBasedTFXIO instances are created at the moment; the attribute
    # checks that follow depend on that.
    self.assertIsInstance(created, record_based_tfxio.RecordBasedTFXIO)
    self.assertEqual(created.telemetry_descriptors, _TELEMETRY_DESCRIPTORS)
    self.assertEqual(created.raw_record_column_name, raw_record_column_name)
    # A schema was supplied, so asking for the Arrow schema must not raise.
    created.ArrowSchema()
def Do(self,
       input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Builds a decoder via the user-supplied function and saves it.

    Args:
      input_dict: input artifacts; only passed to the startup logging.
      output_dict: output artifacts; the sole artifact under `_DATA_VIEW_KEY`
        supplies the URI the decoder is saved to.
      exec_properties: execution properties from which the decoder-creating
        function is resolved using `_CREATE_DECODER_FUNC_KEY`.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    decoder_factory = udf_utils.get_fn(exec_properties,
                                       _CREATE_DECODER_FUNC_KEY)
    # Build the decoder before resolving the output location, keeping the
    # same evaluation order as a single nested call would have.
    decoder = decoder_factory()
    output_uri = value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY).uri
    tf_graph_record_decoder.save_decoder(decoder, output_uri)
def setUp(self):
    """Saves a test decoder and writes test input under a unique temp directory."""
    super(RecordToTensorTfxioTest, self).setUp()

    # A random hex name keeps each test's files in their own directory.
    scratch_dir = os.path.join(
        FLAGS.test_tmpdir, "recordtotensortfxiotest", uuid.uuid4().hex)

    self._decoder_path = scratch_dir
    tf_graph_record_decoder.save_decoder(_DecoderForTesting(),
                                         self._decoder_path)

    # The input lives in an "input" entry inside the decoder directory.
    self._input_path = os.path.join(scratch_dir, "input")
    _write_input(self._input_path)
def test_save_load_decode(self):
    """Round-trips a decoder that exposes a record index tensor name.

    Covers the output type specs, the decoded tensors, access to
    `record_index_tensor_name` after loading (also inside a TF1 graph), and
    saving through the decoder's own `save` method.
    """
    original = _DecoderForTestWithRecordIndexTensorName()
    actual_specs = original.output_type_specs()
    expected_specs = {
        "sparse_tensor":
            tf.SparseTensorSpec(shape=[None, None], dtype=tf.string),
        "ragged_tensor":
            tf.RaggedTensorSpec(
                shape=[None, None], dtype=tf.string, ragged_rank=1),
        "record_index":
            tf.RaggedTensorSpec(
                shape=[None, None], dtype=tf.int64, ragged_rank=1),
        "dense_tensor":
            tf.TensorSpec(shape=[None], dtype=tf.string),
    }
    self.assertEqual(actual_specs, expected_specs)
    self.assertEqual(original.record_index_tensor_name, "record_index")

    tf_graph_record_decoder.save_decoder(original, self._tmp_dir)
    restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
    self.assertEqual(restored.record_index_tensor_name, "record_index")
    self._assert_type_specs_equal(original.output_type_specs(),
                                  restored.output_type_specs())

    records = [b"abc", b"def"]
    decoded = restored.decode_record(records)
    self.assertLen(decoded, len(restored.output_type_specs()))
    self.assertIn("sparse_tensor", decoded)

    sparse = decoded["sparse_tensor"]
    self.assertAllEqual(sparse.values, records)
    self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
    self.assertAllEqual(sparse.dense_shape, [2, 1])

    ragged = decoded["ragged_tensor"]
    self.assertAllEqual(ragged, tf.ragged.constant([[b"abc"], [b"def"]]))

    record_index = decoded["record_index"]
    self.assertAllEqual(record_index, tf.ragged.constant([[0], [1]]))

    dense = decoded["dense_tensor"]
    self.assertAllEqual(dense, records)

    # The record index tensor name must also be readable when the decoder is
    # loaded in graph (non-eager) mode.
    with tf.compat.v1.Graph().as_default():
        self.assertFalse(tf.executing_eagerly())
        restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
        self.assertEqual(restored.record_index_tensor_name, "record_index")

    # Saving through the decoder's own `save` method must work as well.
    second_path = os.path.join(self._tmp_dir, "decoder_2")
    original.save(second_path)
    restored = tf_graph_record_decoder.load_decoder(second_path)
    self.assertEqual(restored.record_index_tensor_name, "record_index")
def test_no_record_index_tensor_name(self):
    """Checks `record_index_tensor_name` stays None across save and load."""
    original = _DecoderForTesting()
    self.assertIsNone(original.record_index_tensor_name)

    tf_graph_record_decoder.save_decoder(original, self._tmp_dir)

    restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
    self.assertIsNone(restored.record_index_tensor_name)

    # Repeat the check with the decoder loaded in graph (non-eager) mode.
    with tf.compat.v1.Graph().as_default():
        self.assertFalse(tf.executing_eagerly())
        restored = tf_graph_record_decoder.load_decoder(self._tmp_dir)
        self.assertIsNone(restored.record_index_tensor_name)
def Do(self,
       input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Builds a decoder via the user-supplied function and saves it.

    The decoder-creating function is imported from a module file when
    `_MODULE_FILE_KEY` is present in `exec_properties`; otherwise it is
    resolved through `udf_utils.get_fn`.

    Args:
      input_dict: unused.
      output_dict: output artifacts; the sole artifact under `_DATA_VIEW_KEY`
        supplies the URI the decoder is saved to.
      exec_properties: execution properties naming the decoder-creating
        function and, optionally, the module file that contains it.
    """
    del input_dict  # Unused.

    if _MODULE_FILE_KEY not in exec_properties:
        decoder_factory = udf_utils.get_fn(exec_properties,
                                           _CREATE_DECODER_FUNC_KEY)
    else:
        module_file = exec_properties.get(_MODULE_FILE_KEY)
        func_name = exec_properties.get(_CREATE_DECODER_FUNC_KEY)
        decoder_factory = import_utils.import_func_from_source(
            module_file, func_name)

    # Build the decoder before resolving the output location, keeping the
    # same evaluation order as a single nested call would have.
    decoder = decoder_factory()
    output_uri = value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY).uri
    tf_graph_record_decoder.save_decoder(decoder, output_uri)
def _write_decoder(decoder=None):
    """Saves a decoder into a fresh temp directory and returns that directory.

    Args:
      decoder: the decoder to save. When None, a new `_DecoderForTesting` is
        created for this call, so no single default instance is constructed
        at definition time and shared between callers.

    Returns:
      Path of the newly created directory (under `FLAGS.test_tmpdir`) that
      the decoder was saved to.
    """
    if decoder is None:
        decoder = _DecoderForTesting()
    result = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
    tf_graph_record_decoder.save_decoder(decoder, result)
    return result
def test_do_not_save_if_record_index_tensor_name_invalid(self):
    """Expects `save_decoder` to raise for an invalid record index tensor name."""
    bad_decoder = _DecoderForTestWithInvalidRecordIndexTensorName()
    expected_message = "record_index_tensor_name"
    with self.assertRaisesRegex(AssertionError, expected_message):
        tf_graph_record_decoder.save_decoder(bad_decoder, self._tmp_dir)