Example #1
def _test_dataset(self,
                  inputs,
                  expected_output=None,
                  expected_err_re=None,
                  **kwargs):
    """Checks that elements produced by CsvDataset match expected output."""
    filenames = self.setup_files(inputs)
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as sess:
            dataset = readers.CsvDataset(filenames, **kwargs)
            nxt = dataset.make_one_shot_iterator().get_next()
            if expected_err_re is None:
                # Verify that output is expected, without errors.
                # Convert str type because py3 tf strings are bytestrings.
                expected_output = [[
                    v.encode('utf-8') if isinstance(v, str) else v
                    for v in op
                ] for op in expected_output]
                for value in expected_output:
                    op = sess.run(nxt)
                    self.assertAllEqual(op, value)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(nxt)
            else:
                # Verify that OpError is produced as expected.
                with self.assertRaisesOpError(expected_err_re):
                    while True:
                        try:
                            sess.run(nxt)
                        except errors.OutOfRangeError:
                            break
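For readers porting this graph-mode harness, a minimal eager-mode (TF 2.x) sketch of the same element-by-element check, using a hypothetical file path and contents, could look like:

import tensorflow as tf

# Hypothetical fixture: two rows of two integer columns.
with open('/tmp/example.csv', 'w') as f:
    f.write('1,2\n3,4\n')

dataset = tf.data.experimental.CsvDataset(
    '/tmp/example.csv', record_defaults=[tf.int32, tf.int32])

expected = [(1, 2), (3, 4)]
for got, want in zip(dataset, expected):
    # Each element is a tuple of scalar tensors, one per column.
    assert tuple(int(t) for t in got) == want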
Example #2
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
  record_defaults = [['']] * 3
  inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
  filenames = self._setup_files(inputs)
  dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
  dataset = dataset.apply(error_ops.ignore_errors())
  self._verify_output_or_err(dataset, [['e', 'f', 'g']])
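The same behavior can be reproduced standalone in eager mode. In this sketch (file path hypothetical), the three records containing an unquoted '"' fail to parse, and ignore_errors() drops them so only the valid row survives:

import tensorflow as tf

with open('/tmp/quotes.csv', 'w') as f:
    f.write('1,2"3,4\na,b,c"d\n9,8"7,6,5\ne,f,g\n')

dataset = tf.data.experimental.CsvDataset(
    '/tmp/quotes.csv', record_defaults=[['']] * 3)

try:
    list(dataset)  # The unquoted '"' makes the first record unparseable.
except tf.errors.OpError as err:
    print('parse failed:', type(err).__name__)

# Dropping bad records instead of failing:
ok = dataset.apply(tf.data.experimental.ignore_errors())
print([tuple(t.numpy() for t in row) for row in ok])  # [(b'e', b'f', b'g')]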
Example #3
def _make_test_datasets(self, inputs, **kwargs):
    # Test by comparing its output to what we could get with map->decode_csv
    filenames = self._setup_files(inputs)
    dataset_expected = core_readers.TextLineDataset(filenames)
    dataset_expected = dataset_expected.map(
        lambda l: parsing_ops.decode_csv(l, **kwargs))
    dataset_actual = readers.CsvDataset(filenames, **kwargs)
    return (dataset_actual, dataset_expected)
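A self-contained version of the comparison this helper enables, assuming eager execution and a hypothetical two-column fixture:

import tensorflow as tf

with open('/tmp/cmp.csv', 'w') as f:
    f.write('1,2.5\n3,4.5\n')

kwargs = {'record_defaults': [[0], [0.0]]}

# Baseline: read lines, then parse each one with decode_csv.
expected = tf.data.TextLineDataset('/tmp/cmp.csv').map(
    lambda line: tf.io.decode_csv(line, **kwargs))
# Fused reader under test.
actual = tf.data.experimental.CsvDataset('/tmp/cmp.csv', **kwargs)

for actual_row, expected_row in zip(actual, expected):
    for a, e in zip(actual_row, expected_row):
        assert a.numpy() == e.numpy()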
Example #4
def benchmarkCsvDatasetWithStrings(self):
    self._setUp(self.STR_VAL)
    for i in range(len(self._filenames)):
        num_cols = self._num_cols[i]
        kwargs = {'record_defaults': [['']] * num_cols}
        dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
        self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
    self._tearDown()
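The "fused" in the benchmark label refers to CsvDataset reading and parsing the file in a single dataset op, in contrast to the two-stage TextLineDataset + decode_csv pipeline built by _make_test_datasets in Example #3; the benchmark presumably exists to measure what that fusion buys.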
Example #5
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
  record_defaults = [['']] * 3
  inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
  filenames = self.setup_files(inputs)
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
      dataset = dataset.apply(error_ops.ignore_errors())
      self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
Example #6
def benchmarkCsvDataset(self):
    self._setUp()
    for i in range(len(self._filenames)):
        num_cols = self._num_cols[i]
        kwargs = {'record_defaults': [[0.0]] * num_cols}
        dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()
        dataset = dataset.batch(self._batch_size)
        self._runBenchmark(dataset, num_cols, 'csv_fused_dataset')
    self._tearDown()
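Outside the benchmark harness (whose _setUp and _runBenchmark helpers are internal), a plain-Python approximation of the same timing loop might be:

import time
import tensorflow as tf

with open('/tmp/bench.csv', 'w') as f:
    for _ in range(10000):
        f.write('1.0,2.0,3.0,4.0\n')

dataset = tf.data.experimental.CsvDataset(
    '/tmp/bench.csv', record_defaults=[[0.0]] * 4).repeat().batch(32)

start = time.time()
for _ in dataset.take(1000):  # 1000 batches of 32 rows each
    pass
elapsed = time.time() - start
print('rows/sec: %.0f' % (1000 * 32 / elapsed))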
Example #7
def ds_func(self, **kwargs):
    compression_type = kwargs.get("compression_type", None)
    if compression_type == "GZIP":
        filename = self._compressed
    elif compression_type is None:
        filename = self._filename
    else:
        raise ValueError("Invalid compression type: %s" % compression_type)

    return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
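A standalone sketch of the GZIP branch, assuming a hypothetical compressed fixture written with Python's gzip module:

import gzip
import tensorflow as tf

with gzip.open('/tmp/data.csv.gz', 'wt') as f:
    f.write('1,2\n3,4\n')

dataset = tf.data.experimental.CsvDataset(
    '/tmp/data.csv.gz',
    record_defaults=[[0], [0]],
    compression_type='GZIP')
print([tuple(int(t) for t in row) for row in dataset])  # [(1, 2), (3, 4)]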
Example #8
def _test_dataset(self,
                  inputs,
                  expected_output=None,
                  expected_err_re=None,
                  linebreak='\n',
                  **kwargs):
  """Checks that elements produced by CsvDataset match expected output."""
  # Convert str type because py3 tf strings are bytestrings
  filenames = self.setup_files(inputs, linebreak)
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      dataset = readers.CsvDataset(filenames, **kwargs)
      self._verify_output_or_err(sess, dataset, expected_output,
                                 expected_err_re)
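The linebreak parameter exists because these tests exercise files with different line endings. For instance, CsvDataset handles CRLF-terminated files, as this eager sketch (path and contents hypothetical) illustrates:

import tensorflow as tf

# Write a file with Windows-style '\r\n' line endings.
with open('/tmp/crlf.csv', 'w', newline='') as f:
    f.write('a,b\r\nc,d\r\n')

dataset = tf.data.experimental.CsvDataset(
    '/tmp/crlf.csv', record_defaults=[['']] * 2)
print([tuple(t.numpy() for t in row) for row in dataset])
# [(b'a', b'b'), (b'c', b'd')]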
Example #9
def _test_dataset(
        self,
        inputs,
        expected_output=None,
        expected_err_re=None,
        linebreak='\n',
        compression_type=None,  # Used for both setup and parsing
        **kwargs):
    """Checks that elements produced by CsvDataset match expected output."""
    # Convert str type because py3 tf strings are bytestrings
    filenames = self._setup_files(inputs, linebreak, compression_type)
    kwargs['compression_type'] = compression_type
    dataset = readers.CsvDataset(filenames, **kwargs)
    self._verify_output_or_err(dataset, expected_output, expected_err_re)