def testParse(self):
    """Parse two CSV lines and compare every output column to its expectation.

    The parser is configured with three column names and three default
    values; the expected result is two byte-string columns followed by one
    float column, each holding one entry per input line.
    """
    names = ["col0", "col1", "col2"]
    defaults = ["", "", 1.4]
    parser = csv_parser.CSVParser(column_names=names, default_values=defaults)

    lines = ["one,two,2.5", "four,five,6.0"]
    line_tensor = tf.constant(lines, dtype=tf.string, shape=[len(lines)])
    source_column = mocks.MockSeries("csv", line_tensor)

    wanted = [
        np.array([b"one", b"four"]),
        np.array([b"two", b"five"]),
        np.array([2.5, 6.0]),
    ]

    parsed_columns = parser(source_column)
    self.assertEqual(3, len(parsed_columns))

    # A single cache dict is shared by every build() call.
    build_cache = {}
    tensors = []
    for column in parsed_columns:
        tensors.append(column.build(build_cache))
    self.assertEqual(3, len(tensors))

    with self.test_session() as sess:
        results = sess.run(tensors)

    for want, got in zip(wanted, results):
        np.testing.assert_array_equal(got, want)
def _from_csv_base(cls, filepatterns, get_default_values, has_header,
                   column_names, num_epochs, num_threads, enqueue_size,
                   batch_size, queue_capacity, min_after_dequeue, shuffle,
                   seed):
    """Create a `DataFrame` from CSV files.

    If `has_header` is false, then `column_names` must be specified. If
    `has_header` is true and `column_names` are specified, then
    `column_names` overrides the names in the header.

    Args:
      filepatterns: a list of file patterns that resolve to CSV files.
      get_default_values: a function that produces a list of default values
        for each column, given the column names.
      has_header: whether or not the CSV files have headers.
      column_names: a list of names for the columns in the CSV files. If
        `None`, the names are read from the header row of the first matching
        file.
      num_epochs: the number of times that the reader should loop through all
        the file names. If set to `None`, then the reader will continue
        indefinitely.
      num_threads: the number of readers that will work in parallel.
      enqueue_size: block size for each read operation.
      batch_size: desired batch size.
      queue_capacity: capacity of the queue that will store parsed lines.
      min_after_dequeue: minimum number of elements that can be left by a
        dequeue operation. Only used if `shuffle` is true.
      shuffle: whether records should be shuffled.
      seed: passed to random shuffle operations. Only used if `shuffle` is
        true.

    Returns:
      A `DataFrame` (an instance of `cls`) holding the parsed CSV columns plus
      an 'index' column taken from the text file source.

    Raises:
      ValueError: no files match `filepatterns`.
      ValueError: `column_names` is `None` and `has_header` is false.
      ValueError: `column_names` is `None` and no header row could be read
        from the first matching file.
      ValueError: the column names contain the reserved name 'index'.
    """
    filenames = _expand_file_names(filepatterns)
    if not filenames:
        raise ValueError("No matching file names.")

    if column_names is None:
        if not has_header:
            raise ValueError(
                "If column_names is None, has_header must be true.")
        # Only the first file's header is consulted for the column names.
        with gfile.GFile(filenames[0]) as f:
            column_names = csv.DictReader(f).fieldnames
        # DictReader.fieldnames is None when the file has no rows at all;
        # fail with a clear message instead of a TypeError further down.
        if column_names is None:
            raise ValueError(
                "Could not read a header row from %s." % filenames[0])

    if "index" in column_names:
        raise ValueError(
            "'index' is reserved and can not be used for a column name.")

    default_values = get_default_values(column_names)

    # Skip the header row when the files have one.
    reader_kwargs = {"skip_header_lines": (1 if has_header else 0)}
    index, value = reader_source.TextFileSource(
        filenames,
        reader_kwargs=reader_kwargs,
        enqueue_size=enqueue_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        queue_capacity=queue_capacity,
        shuffle=shuffle,
        min_after_dequeue=min_after_dequeue,
        num_threads=num_threads,
        seed=seed)()

    parser = csv_parser.CSVParser(column_names, default_values)
    parsed = parser(value)

    # 'index' is added alongside the parsed columns, which is why it is
    # rejected as a user-supplied column name above.
    column_dict = parsed._asdict()
    column_dict["index"] = index

    dataframe = cls()
    dataframe.assign(**column_dict)
    return dataframe