def _test_dataset(
    self,
    inputs,
    expected_output=None,
    expected_err_re=None,
    linebreak='\n',
    compression_type=None,  # Used for both setup and parsing
    **kwargs):
    """Checks that elements produced by CsvDataset match expected output.

    Args:
      inputs: Forwarded to `self._setup_files` to create the input files.
      expected_output: Iterable of rows the dataset should produce. Only
        read when `expected_err_re` is None.
      expected_err_re: If not None, a pattern handed to
        `self.assertRaisesOpError`; building and draining the dataset must
        then raise a matching error.
      linebreak: Forwarded to `self._setup_files`.
      compression_type: Forwarded to `self._setup_files` and also passed to
        `CsvDataset` as the `compression_type` keyword.
      **kwargs: Extra keyword arguments for `readers.CsvDataset`.
    """
    filenames = self._setup_files(inputs, linebreak, compression_type)
    kwargs['compression_type'] = compression_type

    if expected_err_re is None:
        dataset = readers.CsvDataset(filenames, **kwargs)
        # Convert str values to bytes because py3 tf strings are bytestrings.
        encoded_rows = []
        for row in expected_output:
            encoded_rows.append(
                tuple(
                    cell.encode('utf-8') if isinstance(cell, str) else cell
                    for cell in row))
        self.assertDatasetProduces(dataset, encoded_rows)
        return

    # Verify that OpError is produced as expected. Construction happens
    # inside the context so an error raised there is caught as well.
    with self.assertRaisesOpError(expected_err_re):
        dataset = readers.CsvDataset(filenames, **kwargs)
        self.getDatasetOutput(dataset)
def testIgnoreErrWithUnquotedQuotes(self):
    """Expects only the row 'e,f,g' once `ignore_errors` is applied.

    The other three input rows each contain a double quote inside an
    unquoted field.
    """
    rows = ['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']
    csv_files = self._setup_files([rows])
    parsed = readers.CsvDataset(csv_files, record_defaults=[['']] * 3)
    surviving = parsed.apply(error_ops.ignore_errors())
    self.assertDatasetProduces(surviving, [(b'e', b'f', b'g')])
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
    """Expects only the row 'e,f,g' once `ignore_errors` is applied.

    The other three input rows each contain a double quote inside an
    unquoted field. The result is checked through
    `self._verify_output_or_err`.
    """
    rows = ['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']
    csv_files = self._setup_files([rows])
    parsed = readers.CsvDataset(csv_files, record_defaults=[['']] * 3)
    surviving = parsed.apply(error_ops.ignore_errors())
    self._verify_output_or_err(surviving, [['e', 'f', 'g']])
def _make_test_datasets(self, inputs, **kwargs):
    """Builds a CsvDataset plus a reference dataset over the same files.

    The reference pipeline reads the files with `TextLineDataset` and parses
    every line with `parsing_ops.decode_csv`, using the same keyword
    arguments that are given to `CsvDataset`.

    Args:
      inputs: Forwarded to `self._setup_files` to create the input files.
      **kwargs: Keyword arguments passed to both `parsing_ops.decode_csv`
        and `readers.CsvDataset`.

    Returns:
      A `(dataset_actual, dataset_expected)` tuple.
    """
    filenames = self._setup_files(inputs)
    # Reference: what we could get with map->decode_csv.
    reference = core_readers.TextLineDataset(filenames).map(
        lambda line: parsing_ops.decode_csv(line, **kwargs))
    fused = readers.CsvDataset(filenames, **kwargs)
    return (fused, reference)
def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
    """Expects only the row 'e,f,g' once `ignore_errors` is applied.

    The other three input rows each contain a quoted field with a bare
    (undoubled) double quote inside it.
    """
    rows = ['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']
    csv_files = self._setup_files([rows])
    parsed = readers.CsvDataset(csv_files, record_defaults=[['']] * 3)
    surviving = parsed.apply(error_ops.ignore_errors())
    self.assertDatasetProduces(surviving, [(b'e', b'f', b'g')])
def benchmarkCsvDatasetWithStrings(self):
    """Benchmarks a repeated `CsvDataset` with string record defaults.

    Runs `self._runBenchmark` once per entry in `self._filenames`, using the
    column count at the same index in `self._num_cols`.
    """
    self._setUp(self.STR_VAL)
    try:
        for i, filename in enumerate(self._filenames):
            num_cols = self._num_cols[i]
            # The previous version also built a TextLineDataset here and
            # immediately overwrote it; that dead construction is removed.
            dataset = readers.CsvDataset(
                filename, record_defaults=[['']] * num_cols).repeat()
            self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
    finally:
        # Run teardown even when a benchmark run raises.
        # NOTE(review): presumably _tearDown releases what _setUp created -
        # confirm against those helpers.
        self._tearDown()
def benchmark_csv_dataset_with_floats(self):
    """Benchmarks a repeated `CsvDataset` with float record defaults.

    Runs `self._run_benchmark` once per entry in `self._filenames`, using
    the column count at the same index in `self._num_cols`.
    """
    self._set_up(self.FLOAT_VAL)
    try:
        for i, filename in enumerate(self._filenames):
            num_cols = self._num_cols[i]
            # The previous version also built a TextLineDataset here and
            # immediately overwrote it; that dead construction is removed.
            dataset = readers.CsvDataset(
                filename, record_defaults=[[0.0]] * num_cols).repeat()
            self._run_benchmark(dataset, num_cols, 'csv_float_fused_dataset')
    finally:
        # Run teardown even when a benchmark run raises.
        # NOTE(review): presumably _tear_down releases what _set_up created -
        # confirm against those helpers.
        self._tear_down()
def ds_func(self, **kwargs):
    """Builds a repeated `CsvDataset` over the file matching the compression.

    Args:
      **kwargs: Keyword arguments forwarded unchanged to
        `readers.CsvDataset`. The optional `compression_type` entry selects
        the input: "GZIP" reads `self._compressed`, while None (or absent)
        reads `self._filename`.

    Returns:
      The dataset repeated `self._num_epochs` times.

    Raises:
      ValueError: If `compression_type` is anything other than "GZIP" or
        None.
    """
    compression_type = kwargs.get("compression_type", None)
    if compression_type == "GZIP":
        filename = self._compressed
    elif compression_type is None:
        filename = self._filename
    else:
        # Build one formatted message; passing the value as a second
        # positional argument made the error text render as a tuple.
        raise ValueError(
            "Invalid compression type: {!r}".format(compression_type))
    return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
def benchmark_csv_dataset_with_strings(self):
    """Benchmarks a repeated `CsvDataset` with string record defaults.

    Runs `self._run_benchmark` once per entry in `self._filenames`, using
    the column count at the same index in `self._num_cols`. Every run is
    reported with prefix 'csv_strings_fused_dataset' and `benchmark_id=4`.
    """
    self._set_up(self.STR_VAL)
    try:
        for i, filename in enumerate(self._filenames):
            num_cols = self._num_cols[i]
            # The previous version also built a TextLineDataset here and
            # immediately overwrote it; that dead construction is removed.
            dataset = readers.CsvDataset(
                filename, record_defaults=[['']] * num_cols).repeat()
            self._run_benchmark(
                dataset=dataset,
                num_cols=num_cols,
                prefix='csv_strings_fused_dataset',
                benchmark_id=4)
    finally:
        # Run teardown even when a benchmark run raises.
        # NOTE(review): presumably _tear_down releases what _set_up created -
        # confirm against those helpers.
        self._tear_down()
def _test_dataset(
    self,
    inputs,
    expected_output=None,
    expected_err_re=None,
    linebreak='\n',
    compression_type=None,  # Used for both setup and parsing
    **kwargs):
    """Checks that elements produced by CsvDataset match expected output.

    Args:
      inputs: Forwarded to `self._setup_files` to create the input files.
      expected_output: Forwarded to `self._verify_output_or_err`.
      expected_err_re: Forwarded to `self._verify_output_or_err`. If not
        None, it is also handed to `self.assertRaisesOpError`, which wraps
        both dataset construction and verification.
      linebreak: Forwarded to `self._setup_files`.
      compression_type: Forwarded to `self._setup_files` and also passed to
        `CsvDataset` as the `compression_type` keyword.
      **kwargs: Extra keyword arguments for `readers.CsvDataset`.
    """
    filenames = self._setup_files(inputs, linebreak, compression_type)
    kwargs['compression_type'] = compression_type

    if expected_err_re is None:
        dataset = readers.CsvDataset(filenames, **kwargs)
        self._verify_output_or_err(dataset, expected_output, expected_err_re)
        return

    # Verify that OpError is produced as expected. Construction happens
    # inside the context so an error raised there is caught as well.
    with self.assertRaisesOpError(expected_err_re):
        dataset = readers.CsvDataset(filenames, **kwargs)
        self._verify_output_or_err(dataset, expected_output, expected_err_re)