def setUp(self):
    # Prepare the fixture: feedable filename/compression placeholders and a
    # write op that copies records from the input file to the output file.
    super(TFRecordWriterTest, self).setUp()
    self._num_records = 7
    self.filename = array_ops.placeholder(dtypes.string, shape=[])
    self.compression_type = array_ops.placeholder_with_default("", shape=[])
    records_in = readers.TFRecordDataset([self.filename], self.compression_type)
    writer = writers.TFRecordWriter(
        self._outputFilename(), self.compression_type)
    self.writer = writer.write(records_in)
def testFailDataset(self):
    """write() must raise TypeError for an argument that is not a Dataset."""
    writer = writers.TFRecordWriter(
        self._outputFilename(), self.compression_type)
    with self.assertRaises(TypeError):
        writer.write("whoops")
def testFailShape(self):
    """write() must raise TypeError when dataset elements are not scalars."""
    nonscalar_dataset = dataset_ops.Dataset.from_tensors(
        [["hello"], ["world"]])
    writer = writers.TFRecordWriter(
        self._outputFilename(), self.compression_type)
    with self.assertRaises(TypeError):
        writer.write(nonscalar_dataset)
def writer_fn(self, filename, compression_type=""):
    """Return a write op copying `filename`'s records to the output file."""
    source = readers.TFRecordDataset([filename], compression_type)
    sink = writers.TFRecordWriter(self._outputFilename(), compression_type)
    return sink.write(source)
def reduce_func(key, dataset):
    # Write each group's records to a per-key shard file (the shard name is
    # the enclosing `filename` with the key appended), then yield the shard
    # filename so the resulting dataset enumerates the files written.
    shard_filename = string_ops.string_join(
        [filename, string_ops.as_string(key)])
    shard_writer = writers.TFRecordWriter(shard_filename)
    shard_writer.write(dataset.map(lambda _, record: record))
    return dataset_ops.Dataset.from_tensors(shard_filename)
def writer_fn():
    # Build a write op that streams the fixture file's records into the
    # test's output file (closes over `self` from the enclosing test).
    source = readers.TFRecordDataset(self._createFile())
    sink = writers.TFRecordWriter(self._outputFilename())
    return sink.write(source)
def testFailDType(self):
    """write() must raise TypeError when dataset elements are not strings."""
    int_dataset = dataset_ops.Dataset.from_tensors(10)
    with self.assertRaises(TypeError):
        writers.TFRecordWriter(self._outputFilename(), "").write(int_dataset)