예제 #1
0
 def setup(self):
     """Creates ``self._decoder``, an Example-to-RecordBatch decoder.

     If ``self._serialized_schema`` is truthy it is handed to the decoder;
     otherwise the decoder is built without a schema. In both cases
     ``self._produce_large_types`` is forwarded as ``use_large_types``.
     """
     if self._serialized_schema:
         self._decoder = example_coder.ExamplesToRecordBatchDecoder(
             serialized_schema=self._serialized_schema,
             use_large_types=self._produce_large_types)
     else:
         # Pass the flag by keyword. Passed positionally, the bool occupies the
         # first parameter slot (``serialized_schema``) and only works if the
         # binding happens to expose a bool-only overload.
         self._decoder = example_coder.ExamplesToRecordBatchDecoder(
             use_large_types=self._produce_large_types)
예제 #2
0
    def test_decode_large_types(self, schema_text_proto, examples_text_proto,
                                create_expected):
        """Checks decoding with ``use_large_types=True``.

        The decoded batch must equal ``create_expected(pa.large_list,
        pa.large_binary())``. When a schema is supplied, the decoder's
        ``ArrowSchema()`` must also match the expected batch's schema.
        """
        serialized_examples = []
        for pbtxt in examples_text_proto:
            example = text_format.Parse(pbtxt, tf.train.Example())
            serialized_examples.append(example.SerializeToString())

        serialized_schema = (
            None if schema_text_proto is None else text_format.Parse(
                schema_text_proto, schema_pb2.Schema()).SerializeToString())

        # Only pass the schema when one was actually produced (non-empty).
        decoder_kwargs = {"use_large_types": True}
        if serialized_schema:
            decoder_kwargs["serialized_schema"] = serialized_schema
        coder = example_coder.ExamplesToRecordBatchDecoder(**decoder_kwargs)

        decoded = coder.DecodeBatch(serialized_examples)
        self.assertIsInstance(decoded, pa.RecordBatch)
        expected = create_expected(pa.large_list, pa.large_binary())
        self.assertTrue(decoded.equals(expected),
                        "actual: {}\n expected:{}".format(decoded, expected))
        if serialized_schema:
            self.assertTrue(expected.schema.equals(coder.ArrowSchema()))
예제 #3
0
 def _ArrowSchemaNoRawRecordColumn(self) -> pa.Schema:
     """Returns the Arrow schema derived from the TFMD decoding schema.

     Raises:
       ValueError: if ``self._GetSchemaForDecoding()`` returns None.
     """
     tfmd_schema = self._GetSchemaForDecoding()
     if tfmd_schema is not None:
         decoder = example_coder.ExamplesToRecordBatchDecoder(
             tfmd_schema.SerializeToString())
         return decoder.ArrowSchema()
     raise ValueError("TFMD schema not provided. Unable to derive an "
                      "Arrow schema")
예제 #4
0
  def test_invalid_input(self, schema_text_proto, examples_text_proto, error,
                         error_msg_regex):
    """DecodeBatch must raise `error` with a message matching the regex."""
    serialized_examples = []
    for pbtxt in examples_text_proto:
      parsed = text_format.Parse(pbtxt, tf.train.Example())
      serialized_examples.append(parsed.SerializeToString())

    serialized_schema = None
    if schema_text_proto is not None:
      schema = text_format.Parse(schema_text_proto, schema_pb2.Schema())
      serialized_schema = schema.SerializeToString()

    # A falsy (None or empty) serialized schema means "decode without schema".
    coder = (example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
             if serialized_schema else
             example_coder.ExamplesToRecordBatchDecoder())

    with self.assertRaisesRegex(error, error_msg_regex):
      coder.DecodeBatch(serialized_examples)
예제 #5
0
def _validate_sql(sql_query: Text, schema: schema_pb2.Schema):
  """Checks that `sql_query` can be turned into a slice query over `schema`.

  The TFMD `schema` is converted to an Arrow schema via
  `ExamplesToRecordBatchDecoder`. The query is normalized with
  `slicing_util.format_slice_sql_query` and then used to construct a
  `sql_util.RecordBatchSQLSliceQuery`.

  Args:
    sql_query: The slice SQL query to validate.
    schema: The TFMD schema the query is validated against.

  Raises:
    ValueError: If constructing the `RecordBatchSQLSliceQuery` raises. The
      original exception is chained as the `__cause__`. Errors from the schema
      conversion or query formatting steps propagate unchanged.
  """
  arrow_schema = example_coder.ExamplesToRecordBatchDecoder(
      schema.SerializeToString()).ArrowSchema()
  formatted_query = slicing_util.format_slice_sql_query(sql_query)
  try:
    sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema)
  except Exception as e:  # pylint: disable=broad-except
    # Explicit chaining keeps the underlying SQL error as the direct cause.
    raise ValueError('One of the slice SQL query %s raised an exception: %s.'
                     % (sql_query, repr(e))) from e
예제 #6
0
 def setup(self):
     """Runs the parent setup, then ensures ``self._decoder`` exists.

     An existing decoder is left as-is. Otherwise a schema-less
     ``ExamplesToRecordBatchDecoder`` is created.
     """
     super(_KerasEvaluateCombiner, self).setup()
     # TODO(b/180125126): Re-enable use of passed in TensorAdapter after bug
     # requiring matching schema's is fixed.
     # if self._tensor_adapter is None and
     #    self._tensor_adapter_config is not None:
     #   self._tensor_adapter = tensor_adapter.TensorAdapter(
     #       self._tensor_adapter_config)
     if self._decoder is not None:
         return
     self._decoder = example_coder.ExamplesToRecordBatchDecoder()
예제 #7
0
 def test_invalid_feature_type(self):
   """A schema with a STRUCT-typed feature is rejected at construction."""
   schema = text_format.Parse(
       """
       feature {
         name: "a"
         type: STRUCT
       }
       """, schema_pb2.Schema())
   serialized_schema = schema.SerializeToString()
   expected_regex = "Bad field type for feature: a.*"
   with self.assertRaisesRegex(RuntimeError, expected_regex):
     example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
예제 #8
0
    def test_decode(self, schema_text_proto, examples_text_proto, expected):
        """Decoded RecordBatch must equal `expected`, with or without schema."""
        serialized_examples = []
        for pbtxt in examples_text_proto:
            merged = text_format.Merge(pbtxt, tf.train.Example())
            serialized_examples.append(merged.SerializeToString())

        serialized_schema = None
        if schema_text_proto is not None:
            schema = text_format.Merge(schema_text_proto, schema_pb2.Schema())
            serialized_schema = schema.SerializeToString()

        # Decode without a schema when none (or an empty one) was produced.
        coder = (example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
                 if serialized_schema else
                 example_coder.ExamplesToRecordBatchDecoder())

        decoded = coder.DecodeBatch(serialized_examples)
        self.assertIsInstance(decoded, pa.RecordBatch)
        self.assertTrue(decoded.equals(expected),
                        "actual: {}\n expected:{}".format(decoded, expected))
예제 #9
0
    def RecordBatches(
        self, options: dataset_options.RecordBatchesOptions
    ) -> Iterator[pa.RecordBatch]:
        """Yields RecordBatches decoded from the TFRecord files.

        Batching, epochs, and shuffling come from `options` and are handed to
        `dataset_util.make_tf_record_dataset`. Each batch of serialized
        examples is decoded with `self._schema`. When
        `self._raw_record_column_name` is set, the raw records are attached via
        `record_based_tfxio.AppendRawRecordColumn`.
        """
        dataset = dataset_util.make_tf_record_dataset(
            self._file_pattern, options.batch_size, options.drop_final_batch,
            options.num_epochs, options.shuffle, options.shuffle_buffer_size,
            options.shuffle_seed)

        serialized_schema = self._schema.SerializeToString()
        decoder = example_coder.ExamplesToRecordBatchDecoder(serialized_schema)
        for serialized_batch in dataset.as_numpy_iterator():
            record_batch = decoder.DecodeBatch(serialized_batch)
            raw_column = self._raw_record_column_name
            if raw_column is not None:
                record_batch = record_based_tfxio.AppendRawRecordColumn(
                    record_batch, raw_column, serialized_batch.tolist())
            yield record_batch
예제 #10
0
    def _readDatasetIntoBatchedExtracts(self):
        """Reads the raw dataset and groups decoded examples into Extracts.

        Returns:
          A list of dicts, one per chunk of at most `_BATCH_SIZE` serialized
          examples. Each dict maps `constants.ARROW_RECORD_BATCH_KEY` to the
          decoded RecordBatch after it has been passed through
          `AppendRawRecordColumn` with column `constants.ARROW_INPUT_COLUMN`.
        """
        raw_examples = list(
            self._dataset.read_raw_dataset(deserialize=False,
                                           limit=self._max_num_examples()))

        # TODO(b/153996019): Once the TFXIO interface that returns an iterator of
        # RecordBatch is available, clean this up.
        schema = benchmark_utils.read_schema(
            self._dataset.tf_metadata_schema_path())
        coder = example_coder.ExamplesToRecordBatchDecoder(
            serialized_schema=schema.SerializeToString())

        def _to_extract(chunk):
            # Decode one chunk and attach its raw serialized examples.
            decoded = coder.DecodeBatch(chunk)
            with_raw = record_based_tfxio.AppendRawRecordColumn(
                decoded, constants.ARROW_INPUT_COLUMN, chunk)
            return {constants.ARROW_RECORD_BATCH_KEY: with_raw}

        return [
            _to_extract(raw_examples[start:start + _BATCH_SIZE])
            for start in range(0, len(raw_examples), _BATCH_SIZE)
        ]
예제 #11
0
def _get_batched_records(dataset, max_num_examples=None):
    """Decodes a dataset's serialized records in fixed-size batches.

    Args:
      dataset: BenchmarkDataset object.
      max_num_examples: Maximum number of examples to read from the dataset.

    Returns:
      Tuple of (batch_size, records). `batch_size` is always 1000. `records` is
      a fully materialized list (not a lazy iterator) holding the result of
      `DecodeBatch` for each batch of serialized tf.train.Examples.
    """
    batch_size = 1000
    common_variables = _get_common_variables(dataset)
    schema = common_variables.transform_input_dataset_metadata.schema
    decoder = example_coder.ExamplesToRecordBatchDecoder(
        schema.SerializeToString())
    raw_batches = benchmark_utils.batched_iterator(
        dataset.read_raw_dataset(deserialize=False, limit=max_num_examples),
        batch_size)
    decoded_batches = list(map(decoder.DecodeBatch, raw_batches))
    return batch_size, decoded_batches
예제 #12
0
 def test_arrow_schema_not_available_if_tfmd_schema_not_available(self):
   """ArrowSchema() raises RuntimeError on a decoder built without a schema."""
   schemaless_coder = example_coder.ExamplesToRecordBatchDecoder()
   expected_regex = "Unable to get the arrow schema"
   with self.assertRaisesRegex(RuntimeError, expected_regex):
     schemaless_coder.ArrowSchema()
예제 #13
0
 def setup(self):
   """Creates a schema-less Example-to-RecordBatch decoder."""
   decoder = example_coder.ExamplesToRecordBatchDecoder()
   self._decoder = decoder
예제 #14
0
 def setup(self):
     """Creates ``self._decoder``, using the serialized schema when truthy."""
     # An empty argument tuple selects the schema-less constructor.
     schema_args = (
         (self._serialized_schema,) if self._serialized_schema else ())
     self._decoder = example_coder.ExamplesToRecordBatchDecoder(*schema_args)
예제 #15
0
 def setup(self):
     """Creates ``self._decoder`` from ``self._schema`` when it is truthy.

     If ``self._schema`` is falsy, the decoder is built without a schema.
     """
     schema = self._schema
     if not schema:
         self._decoder = example_coder.ExamplesToRecordBatchDecoder()
         return
     self._decoder = example_coder.ExamplesToRecordBatchDecoder(
         schema.SerializeToString())