def _test_dataset(self, inputs, expected_output=None, expected_err_re=None,
                  **kwargs):
  """Verifies that a CsvDataset yields the expected records or raises.

  Exactly one of `expected_output` / `expected_err_re` is meaningful: when
  `expected_err_re` is None, the dataset's elements are compared against
  `expected_output`; otherwise the dataset is drained and the matching
  OpError is expected to surface.
  """
  filenames = self.setup_files(inputs)
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      get_next = readers.CsvDataset(
          filenames, **kwargs).make_one_shot_iterator().get_next()
      if expected_err_re is not None:
        # Drain the dataset until the expected OpError is produced.
        with self.assertRaisesOpError(expected_err_re):
          while True:
            try:
              sess.run(get_next)
            except errors.OutOfRangeError:
              break
      else:
        # py3 tf strings are bytestrings, so encode the expected str values.
        encoded = [[col.encode('utf-8') if isinstance(col, str) else col
                    for col in row] for row in expected_output]
        for row in encoded:
          self.assertAllEqual(sess.run(get_next), row)
        # The iterator must be exhausted after the expected rows.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
  """Rows containing unquoted quotes are dropped when errors are ignored."""
  defaults = [['']] * 3
  inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
  filenames = self._setup_files(inputs)
  # Only the well-formed final row should survive ignore_errors().
  dataset = readers.CsvDataset(filenames, record_defaults=defaults).apply(
      error_ops.ignore_errors())
  self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def _make_test_datasets(self, inputs, **kwargs):
  """Returns (CsvDataset, TextLineDataset+decode_csv) over the same files.

  The second dataset is the reference implementation: fused CsvDataset output
  can be compared against a map of decode_csv over raw lines.
  """
  filenames = self._setup_files(inputs)
  actual = readers.CsvDataset(filenames, **kwargs)
  expected = core_readers.TextLineDataset(filenames).map(
      lambda line: parsing_ops.decode_csv(line, **kwargs))
  return (actual, expected)
def benchmarkCsvDatasetWithStrings(self):
  """Benchmarks the fused CsvDataset op on string-valued columns.

  Fix: the original assigned `dataset = core_readers.TextLineDataset(...)`
  and immediately overwrote it with the CsvDataset on the next line — a dead
  leftover from the unfused benchmark that only wasted graph-construction
  time. The dead assignment is removed.
  """
  self._setUp(self.STR_VAL)
  for i in range(len(self._filenames)):
    num_cols = self._num_cols[i]
    # Every column defaults to an empty string => parsed as string values.
    kwargs = {'record_defaults': [['']] * num_cols}
    dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
    self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
  self._tearDown()
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
  """Rows with unquoted quotes are skipped under ignore_errors()."""
  defaults = [['']] * 3
  inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
  filenames = self.setup_files(inputs)
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      # Only the well-formed final row should survive error suppression.
      dataset = readers.CsvDataset(
          filenames, record_defaults=defaults).apply(
              error_ops.ignore_errors())
      self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
def benchmarkCsvDataset(self):
  """Benchmarks the fused CsvDataset op on float-valued columns.

  Fix: the original assigned `dataset = core_readers.TextLineDataset(...)`
  and immediately overwrote it with the CsvDataset on the next line — a dead
  leftover assignment. It is removed; the batching and benchmark run are
  unchanged.
  """
  self._setUp()
  for i in range(len(self._filenames)):
    num_cols = self._num_cols[i]
    # Every column defaults to 0.0 => parsed as float values.
    kwargs = {'record_defaults': [[0.0]] * num_cols}
    dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
    dataset = dataset.batch(self._batch_size)
    self._runBenchmark(dataset, num_cols, 'csv_fused_dataset')
  self._tearDown()
def ds_func(self, **kwargs):
  """Builds a repeated CsvDataset over the matching fixture file.

  Picks the gzip-compressed fixture when `compression_type == "GZIP"`, the
  plain fixture when it is None, and raises ValueError otherwise.

  Fix: the original `raise ValueError("Invalid compression type:",
  compression_type)` passed two positional args, so the error rendered as a
  tuple (`('Invalid compression type:', 'X')`); the message is now a single
  formatted string.
  """
  compression_type = kwargs.get("compression_type", None)
  if compression_type == "GZIP":
    filename = self._compressed
  elif compression_type is None:
    filename = self._filename
  else:
    raise ValueError("Invalid compression type: %s" % compression_type)

  return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
def _test_dataset(self, inputs, expected_output=None, expected_err_re=None,
                  linebreak='\n', **kwargs):
  """Checks CsvDataset elements (or the raised error) against expectations.

  Note: py3 tf strings are bytestrings; the verification helper handles the
  str-to-bytes conversion of `expected_output`.
  """
  filenames = self.setup_files(inputs, linebreak)
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      self._verify_output_or_err(sess,
                                 readers.CsvDataset(filenames, **kwargs),
                                 expected_output, expected_err_re)
def _test_dataset(self,
                  inputs,
                  expected_output=None,
                  expected_err_re=None,
                  linebreak='\n',
                  compression_type=None,
                  **kwargs):
  """Checks that elements produced by CsvDataset match expected output.

  `compression_type` is used twice: once when writing the fixture files and
  once when constructing the CsvDataset that reads them back.
  """
  filenames = self._setup_files(inputs, linebreak, compression_type)
  # Forward the same compression setting to the reader.
  kwargs['compression_type'] = compression_type
  self._verify_output_or_err(
      readers.CsvDataset(filenames, **kwargs), expected_output,
      expected_err_re)