Example No. 1
  # Method of a tf.test.TestCase subclass. Assumes `import numpy as np`,
  # `import tensorflow as tf`, and the `csv_parser` and `mocks` modules from
  # tf.contrib.learn's dataframe package.
  def testParse(self):
    # Three named columns; the default values make col0/col1 strings and
    # col2 a float.
    parser = csv_parser.CSVParser(column_names=["col0", "col1", "col2"],
                                  default_values=["", "", 1.4])
    csv_lines = ["one,two,2.5", "four,five,6.0"]
    csv_input = tf.constant(csv_lines, dtype=tf.string, shape=[len(csv_lines)])
    csv_column = mocks.MockSeries("csv", csv_input)
    expected_output = [np.array([b"one", b"four"]),
                       np.array([b"two", b"five"]),
                       np.array([2.5, 6.0])]
    # Applying the parser to the raw-line column yields one column per name.
    output_columns = parser(csv_column)
    self.assertEqual(3, len(output_columns))
    cache = {}
    output_tensors = [o.build(cache) for o in output_columns]
    self.assertEqual(3, len(output_tensors))
    with self.test_session() as sess:
      output = sess.run(output_tensors)
      for expected, actual in zip(expected_output, output):
        np.testing.assert_array_equal(actual, expected)
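For orientation: the parsing tested here mirrors TensorFlow's decode_csv op,
where each entry in `default_values` fixes the corresponding column's dtype
and fills empty fields. A minimal standalone sketch with the public op
(assuming TensorFlow 2.x, where it is exposed as `tf.io.decode_csv`):

import tensorflow as tf

lines = tf.constant(["one,two,2.5", "four,five,6.0"])
# One default per column; a string default yields tf.string, 1.4 yields
# tf.float32.
col0, col1, col2 = tf.io.decode_csv(lines, record_defaults=["", "", 1.4])
print(col0.numpy())  # [b'one' b'four']
print(col2.numpy())  # [2.5 6. ]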
Example No. 2
  def _from_csv_base(cls, filepatterns, get_default_values, has_header,
                     column_names, num_epochs, num_threads, enqueue_size,
                     batch_size, queue_capacity, min_after_dequeue, shuffle,
                     seed):
    """Create a `DataFrame` from `tensorflow.Example`s.

    If `has_header` is false, then `column_names` must be specified. If
    `has_header` is true and `column_names` are specified, then `column_names`
    overrides the names in the header.

    Args:
      filepatterns: a list of file patterns that resolve to CSV files.
      get_default_values: a function that produces a list of default values for
        each column, given the column names.
      has_header: whether or not the CSV files have headers.
      column_names: a list of names for the columns in the CSV files.
      num_epochs: the number of times that the reader should loop through all
        the file names. If set to `None`, then the reader will continue
        indefinitely.
      num_threads: the number of readers that will work in parallel.
      enqueue_size: block size for each read operation.
      batch_size: desired batch size.
      queue_capacity: capacity of the queue that will store parsed lines.
      min_after_dequeue: minimum number of elements that can be left by a
        dequeue operation. Only used if `shuffle` is true.
      shuffle: whether records should be shuffled. Defaults to true.
      seed: passed to random shuffle operations. Only used if `shuffle` is true.

    Returns:
      A `DataFrame` with a column for each name in `column_names` (plus an
      `index` column), filled with rows from the files matching `filepatterns`.

    Raises:
      ValueError: no files match `filepatterns`.
      ValueError: `column_names` contains the reserved name 'index'.
    """
    filenames = _expand_file_names(filepatterns)
    if not filenames:
      raise ValueError("No matching file names.")

    if column_names is None:
      if not has_header:
        raise ValueError("If column_names is None, has_header must be true.")
      with gfile.GFile(filenames[0]) as f:
        column_names = csv.DictReader(f).fieldnames

    if "index" in column_names:
      raise ValueError(
          "'index' is reserved and cannot be used as a column name.")

    default_values = get_default_values(column_names)

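    # Read raw lines in batches. `index` is the reader's key for each line
    # (typically "<filename>:<line number>") and `value` is the line's text.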
    reader_kwargs = {"skip_header_lines": (1 if has_header else 0)}
    index, value = reader_source.TextFileSource(
        filenames,
        reader_kwargs=reader_kwargs,
        enqueue_size=enqueue_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        queue_capacity=queue_capacity,
        shuffle=shuffle,
        min_after_dequeue=min_after_dequeue,
        num_threads=num_threads,
        seed=seed)()
    parser = csv_parser.CSVParser(column_names, default_values)
    parsed = parser(value)

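    # Expose each parsed column, plus the reader-supplied index, as columns
    # of the resulting DataFrame.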
    column_dict = parsed._asdict()
    column_dict["index"] = index

    dataframe = cls()
    dataframe.assign(**column_dict)
    return dataframe
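For reference, a hedged usage sketch: a public wrapper such as
`TensorFlowDataFrame.from_csv` (the name and defaults here are assumptions,
not the verbatim API) would forward to `_from_csv_base`, supplying a
`get_default_values` callable that maps the resolved column names to one
default per column:

# Hypothetical wrapper call; argument names and values are illustrative.
df = TensorFlowDataFrame.from_csv(
    ["data/part-*.csv"],             # filepatterns, glob-expanded internally
    default_values=["", "", 0.0],    # one default per column; dtypes follow
    has_header=True,                 # column names come from the file header
    batch_size=32,
    shuffle=True,
    seed=42)

A wrapper like this would derive the callable as, e.g.,
`get_default_values = lambda column_names: default_values`.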