def setup(self):
  if self._serialized_schema:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder(
        self._serialized_schema, self._produce_large_types)
  else:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder(
        self._produce_large_types)
def test_decode_large_types(self, schema_text_proto, examples_text_proto,
                            create_expected):
  serialized_examples = [
      text_format.Parse(pbtxt, tf.train.Example()).SerializeToString()
      for pbtxt in examples_text_proto
  ]
  serialized_schema = None
  if schema_text_proto is not None:
    serialized_schema = text_format.Parse(
        schema_text_proto, schema_pb2.Schema()).SerializeToString()
  if serialized_schema:
    coder = example_coder.ExamplesToRecordBatchDecoder(
        serialized_schema=serialized_schema, use_large_types=True)
  else:
    coder = example_coder.ExamplesToRecordBatchDecoder(use_large_types=True)
  result = coder.DecodeBatch(serialized_examples)
  self.assertIsInstance(result, pa.RecordBatch)
  expected = create_expected(pa.large_list, pa.large_binary())
  self.assertTrue(
      result.equals(expected),
      "actual: {}\nexpected: {}".format(result, expected))
  if serialized_schema:
    self.assertTrue(expected.schema.equals(coder.ArrowSchema()))
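For reference, the large-types path above can be exercised on its own. A minimal sketch using only the constructor arguments that appear in this test (use_large_types, plus DecodeBatch over a list of serialized examples); the feature name and value are made up:

import tensorflow as tf
from tfx_bsl.coders import example_coder

# A single example with one bytes feature; "f" is a made-up feature name.
example = tf.train.Example()
example.features.feature["f"].bytes_list.value.append(b"hello")

# Without a schema the decoder infers types from the data; with
# use_large_types=True, bytes features come back as large_list<large_binary>
# rather than list<binary>.
coder = example_coder.ExamplesToRecordBatchDecoder(use_large_types=True)
record_batch = coder.DecodeBatch([example.SerializeToString()])
print(record_batch.schema)  # e.g. f: large_list<item: large_binary>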
def _ArrowSchemaNoRawRecordColumn(self) -> pa.Schema:
  schema = self._GetSchemaForDecoding()
  if schema is None:
    raise ValueError("TFMD schema not provided. Unable to derive an "
                     "Arrow schema")
  return example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
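The pattern above, building a throwaway decoder purely to obtain its Arrow schema, works without decoding any data. A small sketch with a made-up single-feature TFMD schema:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.coders import example_coder

# Made-up one-feature TFMD schema for illustration.
schema = text_format.Parse(
    """
    feature {
      name: "age"
      type: INT
    }
    """, schema_pb2.Schema())

# No examples are needed; the Arrow schema is derived from the TFMD
# schema alone.
arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
    schema.SerializeToString()).ArrowSchema()
print(arrow_schema)  # e.g. age: list<item: int64>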
def test_invalid_input(self, schema_text_proto, examples_text_proto, error,
                       error_msg_regex):
  serialized_examples = [
      text_format.Parse(pbtxt, tf.train.Example()).SerializeToString()
      for pbtxt in examples_text_proto
  ]
  serialized_schema = None
  if schema_text_proto is not None:
    serialized_schema = text_format.Parse(
        schema_text_proto, schema_pb2.Schema()).SerializeToString()
  if serialized_schema:
    coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
  else:
    coder = example_coder.ExamplesToRecordBatchDecoder()
  with self.assertRaisesRegex(error, error_msg_regex):
    coder.DecodeBatch(serialized_examples)
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    raise ValueError('The slice SQL query %s raised an exception: %s.' %
                     (sql_query, repr(e)))
def setup(self):
  super(_KerasEvaluateCombiner, self).setup()
  # TODO(b/180125126): Re-enable use of passed-in TensorAdapter after bug
  # requiring matching schemas is fixed.
  # if (self._tensor_adapter is None and
  #     self._tensor_adapter_config is not None):
  #   self._tensor_adapter = tensor_adapter.TensorAdapter(
  #       self._tensor_adapter_config)
  if self._decoder is None:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder()
def test_invalid_feature_type(self):
  serialized_schema = text_format.Parse(
      """
      feature {
        name: "a"
        type: STRUCT
      }
      """, schema_pb2.Schema()).SerializeToString()
  with self.assertRaisesRegex(RuntimeError,
                              "Bad field type for feature: a.*"):
    _ = example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
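For contrast, construction succeeds for feature types the coder can map onto tf.train.Feature value lists. A sketch, assuming the usual TFMD correspondence of BYTES/INT/FLOAT to bytes_list/int64_list/float_list:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.coders import example_coder

# BYTES, INT and FLOAT correspond to tf.train.Feature's bytes_list,
# int64_list and float_list; STRUCT has no such counterpart, which is
# why the test above expects construction to fail.
serialized_schema = text_format.Parse(
    """
    feature { name: "s" type: BYTES }
    feature { name: "i" type: INT }
    feature { name: "f" type: FLOAT }
    """, schema_pb2.Schema()).SerializeToString()

# Constructing the decoder succeeds for all three supported types.
_ = example_coder.ExamplesToRecordBatchDecoder(serialized_schema)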
def test_decode(self, schema_text_proto, examples_text_proto, expected):
  serialized_examples = [
      text_format.Merge(pbtxt, tf.train.Example()).SerializeToString()
      for pbtxt in examples_text_proto
  ]
  serialized_schema = None
  if schema_text_proto is not None:
    serialized_schema = text_format.Merge(
        schema_text_proto, schema_pb2.Schema()).SerializeToString()
  if serialized_schema:
    coder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
  else:
    coder = example_coder.ExamplesToRecordBatchDecoder()
  result = coder.DecodeBatch(serialized_examples)
  self.assertIsInstance(result, pa.RecordBatch)
  self.assertTrue(
      result.equals(expected),
      "actual: {}\nexpected: {}".format(result, expected))
def RecordBatches(
    self, options: dataset_options.RecordBatchesOptions
) -> Iterator[pa.RecordBatch]:
  dataset = dataset_util.make_tf_record_dataset(
      self._file_pattern, options.batch_size, options.drop_final_batch,
      options.num_epochs, options.shuffle, options.shuffle_buffer_size,
      options.shuffle_seed)
  decoder = example_coder.ExamplesToRecordBatchDecoder(
      self._schema.SerializeToString())
  for examples in dataset.as_numpy_iterator():
    decoded = decoder.DecodeBatch(examples)
    if self._raw_record_column_name is None:
      yield decoded
    else:
      yield record_based_tfxio.AppendRawRecordColumn(
          decoded, self._raw_record_column_name, examples.tolist())
def _readDatasetIntoBatchedExtracts(self):
  """Read the raw dataset and massage examples into batched Extracts."""
  serialized_examples = list(
      self._dataset.read_raw_dataset(
          deserialize=False, limit=self._max_num_examples()))

  # TODO(b/153996019): Once the TFXIO interface that returns an iterator of
  # RecordBatch is available, clean this up.
  coder = example_coder.ExamplesToRecordBatchDecoder(
      serialized_schema=benchmark_utils.read_schema(
          self._dataset.tf_metadata_schema_path()).SerializeToString())
  batches = []
  for i in range(0, len(serialized_examples), _BATCH_SIZE):
    example_batch = serialized_examples[i:i + _BATCH_SIZE]
    record_batch = record_based_tfxio.AppendRawRecordColumn(
        coder.DecodeBatch(example_batch), constants.ARROW_INPUT_COLUMN,
        example_batch)
    batches.append({constants.ARROW_RECORD_BATCH_KEY: record_batch})
  return batches
def _get_batched_records(dataset, max_num_examples=None):
  """Returns a (batch_size, batched records) tuple for the dataset.

  Args:
    dataset: BenchmarkDataset object.
    max_num_examples: Maximum number of examples to read from the dataset.

  Returns:
    Tuple of (batch_size, list of batched records), where each record is an
    Arrow RecordBatch decoded from a batch of serialized tf.train.Examples.
  """
  batch_size = 1000
  common_variables = _get_common_variables(dataset)
  converter = example_coder.ExamplesToRecordBatchDecoder(
      common_variables.transform_input_dataset_metadata.schema
      .SerializeToString())
  serialized_records = benchmark_utils.batched_iterator(
      dataset.read_raw_dataset(deserialize=False, limit=max_num_examples),
      batch_size)
  records = [converter.DecodeBatch(x) for x in serialized_records]
  return batch_size, records
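benchmark_utils.batched_iterator is referenced above but not shown; a plausible stand-in, hypothetical rather than the actual implementation, would chunk any iterable into fixed-size lists:

import itertools

def batched_iterator(iterable, batch_size):
  """Yields lists of up to batch_size items from iterable.

  Hypothetical stand-in for benchmark_utils.batched_iterator.
  """
  iterator = iter(iterable)
  while True:
    batch = list(itertools.islice(iterator, batch_size))
    if not batch:
      return
    yield batch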
def test_arrow_schema_not_available_if_tfmd_schema_not_available(self):
  coder = example_coder.ExamplesToRecordBatchDecoder()
  with self.assertRaisesRegex(RuntimeError, "Unable to get the arrow schema"):
    _ = coder.ArrowSchema()
def setup(self):
  self._decoder = example_coder.ExamplesToRecordBatchDecoder()
def setup(self):
  if self._serialized_schema:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder(
        self._serialized_schema)
  else:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder()
def setup(self):
  if self._schema:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder(
        self._schema.SerializeToString())
  else:
    self._decoder = example_coder.ExamplesToRecordBatchDecoder()
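The setup() variants above follow the Beam DoFn lifecycle: only picklable state (a serialized schema, or nothing) crosses the worker boundary in __init__, and the decoder itself, which wraps native code, is constructed once per instance in setup(). A minimal sketch of a surrounding DoFn; the class name and the shape of process() are made up for illustration:

import apache_beam as beam
from tfx_bsl.coders import example_coder

class _DecodeBatchDoFn(beam.DoFn):  # made-up name for illustration
  """Decodes batches of serialized tf.train.Examples into RecordBatches."""

  def __init__(self, serialized_schema=None):
    # Only the (picklable) serialized schema is stored at construction time.
    self._serialized_schema = serialized_schema
    self._decoder = None

  def setup(self):
    # Runs once per DoFn instance on the worker, after deserialization.
    if self._serialized_schema:
      self._decoder = example_coder.ExamplesToRecordBatchDecoder(
          self._serialized_schema)
    else:
      self._decoder = example_coder.ExamplesToRecordBatchDecoder()

  def process(self, serialized_examples):
    yield self._decoder.DecodeBatch(serialized_examples)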