Example #1
    def _input_fn():

        #file_names = data.Dataset.list_files(files_pattern)
        file_names = tf.matching_files(files_pattern)

        if Params.EAGER:
            print(file_names)

        dataset = data.TextLineDataset(file_names)

        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(count=num_epochs,
                                               buffer_size=batch_size * 2))

        dataset = dataset.apply(
            tf.contrib.data.map_and_batch(parse_tsv,
                                          batch_size=batch_size,
                                          num_parallel_batches=2))

        dataset = dataset.prefetch(batch_size)

        if Params.EAGER:
            return dataset

        iterator = dataset.make_one_shot_iterator()
        features, target = iterator.get_next()
        return features, target
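For reference, a rough sketch of the same pipeline written with plain tf.data transformations instead of the fused tf.contrib.data ops (shuffle_and_repeat is roughly shuffle followed by repeat, and map_and_batch is roughly map followed by batch); parse_tsv, files_pattern, num_epochs and batch_size are assumed to be the same objects used above:

# Rough non-contrib equivalent of the pipeline above (TF 1.x style).
import tensorflow as tf

def plain_input_fn(files_pattern, parse_tsv, num_epochs, batch_size):
    file_names = tf.matching_files(files_pattern)
    dataset = tf.data.TextLineDataset(file_names)
    # shuffle_and_repeat(count, buffer_size) ~ shuffle(buffer_size).repeat(count)
    dataset = dataset.shuffle(buffer_size=batch_size * 2).repeat(num_epochs)
    # map_and_batch(fn, batch_size, ...) ~ map(fn, ...).batch(batch_size)
    dataset = dataset.map(parse_tsv, num_parallel_calls=2).batch(batch_size)
    dataset = dataset.prefetch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()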
def create_labeled_data_iterator_with_context(context, txt1, txt2, labels,
                                              vocab_table, batch_size):
    context_dataset = create_wordindex_with_length_dataset(
        context, vocab_table)
    text1_dataset = create_wordindex_with_length_dataset(txt1, vocab_table, -1)
    text2_dataset = create_wordindex_with_length_dataset(txt2, vocab_table, -1)

    # Labels is a single float
    labels_dataset = data.TextLineDataset(labels)
    labels_dataset = labels_dataset.map(lambda line: tf.string_to_number(line))
    labels_dataset = labels_dataset.map(
        lambda label: tf.cast(label, tf.float32))

    dataset = data.Dataset.zip(
        (context_dataset, text1_dataset, text2_dataset, labels_dataset))

    # Separate out lengths of txt1 and txt2
    dataset = dataset.map(lambda ctx, t1, t2, label:
                          (ctx[0], t1[0], t2[0], ctx[1], t1[1], t2[1], label))

    # Create a padded batch
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(tf.TensorShape([None]),
                                                  tf.TensorShape([None]),
                                                  tf.TensorShape([None]),
                                                  tf.TensorShape([]),
                                                  tf.TensorShape([]),
                                                  tf.TensorShape([]),
                                                  tf.TensorShape([])))

    iterator = dataset.make_initializable_iterator()
    ctx, txt1, txt2, len_ctx, len_txt1, len_txt2, label = iterator.get_next()

    return BatchedCtxInput(ctx, txt1, txt2, len_ctx, len_txt1, len_txt2, label,
                           iterator.initializer)
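As a small, self-contained sketch of what padded_batch does here (toy data, not part of the original pipeline): variable-length integer sequences are zero-padded to the longest sequence in the batch, while scalar fields keep shape [].

# Toy illustration of padded_batch on variable-length sequences (TF 1.x style).
import tensorflow as tf

# Each element i in [1, 4) becomes the vector [0, 1, ..., i - 1].
sequences = tf.data.Dataset.range(1, 4).map(lambda n: tf.range(n))
batched = sequences.padded_batch(3, padded_shapes=tf.TensorShape([None]))

with tf.Session() as sess:
    print(sess.run(batched.make_one_shot_iterator().get_next()))
    # [[0 0 0]
    #  [0 1 0]
    #  [0 1 2]]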
def input_fn(files_name_pattern,
             mode=tf.estimator.ModeKeys.EVAL,
             skip_header_lines=0,
             num_epochs=1,
             batch_size=512):

    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    num_threads = multiprocessing.cpu_count() if pm.MULTI_THREADING else 1

    # The shuffle buffer size is the number of elements from this dataset from
    # which the shuffled dataset will sample; prefetch prepares elements ahead of time.
    buffer_size_prefetch = 2 * batch_size + 1
    buffer_size_shuffle = pm.TRAIN_SIZE  # alternatively: 10 * batch_size + 1

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TextLineDataset(filenames=file_names)
    dataset = dataset.skip(skip_header_lines)

    if shuffle:
        dataset = dataset.shuffle(buffer_size_shuffle)

    dataset = dataset.map(lambda tsv_row: parse_tsv_row(tsv_row), num_parallel_calls=num_threads)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.prefetch(buffer_size_prefetch)
    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, parse_label_column(target)
Example #4
def input_fn(file_name_pattern,
             mode=tf.estimator.ModeKeys.EVAL,
             skip_header_lines=1,
             num_epochs=1,
             batch_size=200):
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    num_threads = multiprocessing.cpu_count()
    buffer_size = 2 * batch_size + 1
    print("")
    print("* data input_fn:")
    print("================")
    print("Input file(s): {}".format(file_name_pattern))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")
    dataset = data.TextLineDataset(filenames=file_name_pattern)
    dataset = dataset.skip(skip_header_lines)
    if shuffle:
        dataset = dataset.shuffle(buffer_size)
    dataset = dataset.map(lambda tsv_row: parse_tsv_row(tsv_row),
                          num_parallel_calls=num_threads)
    dataset = dataset.batch(batch_size)
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.repeat(None)
    else:
        dataset = dataset.repeat(1)
    dataset = dataset.prefetch(buffer_size)
    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, parse_label_column(target)
Example #5
    def _input_fn():

        shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

        data_size = HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None

        num_threads = multiprocessing.cpu_count() if multi_threading else 1

        buffer_size = 2 * batch_size + 1

        tf.compat.v1.logging.info("")
        tf.compat.v1.logging.info("* datasets input_fn:")
        tf.compat.v1.logging.info("================")
        tf.compat.v1.logging.info(("Mode: {}".format(mode)))
        tf.compat.v1.logging.info(
            ("Input file(s): {}".format(file_names_pattern)))
        tf.compat.v1.logging.info(("Files encoding: {}".format(file_encoding)))
        tf.compat.v1.logging.info(("Data size: {}".format(data_size)))
        tf.compat.v1.logging.info(("Batch size: {}".format(batch_size)))
        tf.compat.v1.logging.info(("Epoch count: {}".format(num_epochs)))
        tf.compat.v1.logging.info(("Thread count: {}".format(num_threads)))
        tf.compat.v1.logging.info(("Shuffle: {}".format(shuffle)))
        tf.compat.v1.logging.info("================")
        tf.compat.v1.logging.info("")

        file_names = tf.io.matching_files(file_names_pattern)

        if file_encoding == 'csv':
            dataset = data.TextLineDataset(filenames=file_names)
            dataset = dataset.skip(skip_header_lines)
            dataset = dataset.map(lambda csv_row: parse_csv(csv_row))

        else:
            dataset = data.TFRecordDataset(filenames=file_names)
            dataset = dataset.map(lambda tf_example: parse_tf_example(
                tf_example, HYPER_PARAMS, mode=mode),
                                  num_parallel_calls=num_threads)

        dataset = dataset.map(
            lambda features: get_features_target_tuple(features, mode=mode),
            num_parallel_calls=num_threads)
        dataset = dataset.map(lambda features, target: (process_features(
            features, HYPER_PARAMS=HYPER_PARAMS, mode=mode), target),
                              num_parallel_calls=num_threads)

        if shuffle:
            dataset = dataset.shuffle(buffer_size)

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size)
        dataset = dataset.repeat(num_epochs)

        iterator = dataset.make_one_shot_iterator()
        features, target = iterator.get_next()

        features, target = posprocessing(features, target, mode, HYPER_PARAMS)

        return features, target
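Note that these tf.compat.v1.logging.info messages only show up if the logging verbosity is raised to INFO (the default threshold is typically WARN). A one-line sketch, assuming it runs before the input function is built:

import tensorflow as tf

# Raise the logging threshold so the info() calls above are actually printed.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)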
Example #6
    def _input_fn():

        shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

        data_size = task.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None

        num_threads = multiprocessing.cpu_count() if multi_threading else 1

        buffer_size = 2 * batch_size + 1

        print("")
        print("* data input_fn:")
        print("================")
        print("Mode: {}".format(mode))
        print("Input file(s): {}".format(file_names_pattern))
        print("Files encoding: {}".format(file_encoding))
        print("Data size: {}".format(data_size))
        print("Batch size: {}".format(batch_size))
        print("Epoch count: {}".format(num_epochs))
        print("Thread count: {}".format(num_threads))
        print("Shuffle: {}".format(shuffle))
        print("================")
        print("")

        file_names = tf.matching_files(file_names_pattern)

        if file_encoding == 'csv':
            dataset = data.TextLineDataset(filenames=file_names)
            dataset = dataset.skip(skip_header_lines)
            # parse_csv just appends column_names to the data
            dataset = dataset.map(lambda csv_row: parse_csv(csv_row))

        else:
            dataset = data.TFRecordDataset(filenames=file_names)
            dataset = dataset.map(
                lambda tf_example: parse_tf_example(tf_example),
                num_parallel_calls=num_threads)

        dataset = dataset.map(
            lambda features: get_features_target_tuple(
                features),  # features here is a dictionary
            num_parallel_calls=num_threads)
        dataset = dataset.map(lambda features, target:
                              (process_features(features), target),
                              num_parallel_calls=num_threads)

        if shuffle:
            dataset = dataset.shuffle(buffer_size)

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size)
        dataset = dataset.repeat(num_epochs)

        iterator = dataset.make_one_shot_iterator()
        features, target = iterator.get_next()

        return features, target  # features: dictionary, target: value
def create_wordindex_with_length_dataset(dataset_file_path,
                                         vocab_table,
                                         max_len=160):
    text_dataset = data.TextLineDataset(dataset_file_path)
    text_dataset = text_dataset.map(
        lambda line: tf.string_split([line]).values)
    text_dataset = text_dataset.map(lambda words: vocab_table.lookup(words))
    if max_len > 0:
        text_dataset = text_dataset.map(lambda words: words[-max_len:])
    text_dataset = text_dataset.map(lambda words:
                                    (tf.cast(words, tf.int32), tf.size(words)))
    return text_dataset
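A sketch of how a vocab_table compatible with this helper might be constructed; the file names below are hypothetical, and in TF 1.x the lookup table must be initialized (tf.tables_initializer) before use, alongside the initializable iterator used elsewhere in these examples:

# Hypothetical vocab/table setup for create_wordindex_with_length_dataset.
import tensorflow as tf

# "vocab.txt" (one token per line) and "text.txt" are placeholder file names.
vocab_table = tf.contrib.lookup.index_table_from_file(
    vocabulary_file="vocab.txt", default_value=0)

dataset = create_wordindex_with_length_dataset("text.txt", vocab_table)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run([tf.tables_initializer(), iterator.initializer])
    word_ids, length = sess.run(next_element)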
    def __init__(self,
                 sample_file_path,
                 offset_file_path=offset_path,
                 is_training=True,
                 batch_size=2,
                 shuffle=True,
                 buffer_size=500,
                 same_prob=0.5,
                 inference_mode=False):
        """Creates a new ImageDataGenerator.
        
        Receives a path string to a text file, which consists of many lines, where each line specifies the relative
        location of an image. Using this data, this class will create TensorFlow dataset that can be used to train
        rectifynet.
        
        :param sample_file_path: Path to the sample csv file
        :param offset_path: Path to the offset file for record random retrieval
        :param mode: A boolean value indicating "train" or "validation" status. Depending on this value, pre-processing
            is done differently.
        :param batch_size: Number of images per batch.
        :param shuffle: Whether or not to shuffle the data in the dataset and the initial file list.
        :param buffer_size: Number of image dirs used as buffer for TensorFlows shuffling of the dataset. 
            If not specified, the entire txt_file will be buffered into memory for shuffling.
        """
        self.sample_file_path = sample_file_path
        self.offset_file_path = offset_file_path
        self.is_training = is_training
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.same_prob = same_prob
        self.prod_indices = range(batch_size)
        self.inference_mode = inference_mode
        if not self.inference_mode:
            self._read_csv_file()
        dataset = data.TextLineDataset(sample_file_path).skip(1)

        if shuffle and not self.inference_mode:
            dataset = dataset.shuffle(buffer_size=self.buffer_size)

        if not self.inference_mode:
            dataset = dataset.map(
                lambda row: tf.py_func(self._data_augment, [row, True], [
                    tf.float32, tf.bool, tf.float32, tf.bool,
                    tf.int32, tf.int16, tf.int16
                ]))
        else:
            dataset = dataset.map(lambda row: tf.py_func(
                self._data_augment, [row, False],
                [tf.float32, tf.bool, tf.int32, tf.int32]))

        self.data = dataset.batch(self.batch_size)
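A sketch of how this class might be consumed downstream (the class name ImageDataGenerator is taken from the docstring above; the CSV paths are placeholders):

import tensorflow as tf

# Hypothetical usage; "samples.csv" and "offsets.csv" are placeholder paths.
generator = ImageDataGenerator(sample_file_path="samples.csv",
                               offset_file_path="offsets.csv",
                               batch_size=2)
iterator = generator.data.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    batch = sess.run(next_batch)  # tuple matching the py_func output dtypes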
Example #9
def csv_input_fn(file_names,
                 mode=tf.estimator.ModeKeys.EVAL,
                 skip_header_lines=0,
                 num_epochs=None,
                 batch_size=200):
    # Shuffle during training; not needed for evaluation
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    # Multi-threading
    num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1
    # Print input pipeline info
    print("")
    print("Data input_fn:")
    print("================")
    print("Input file(s): {}".format(file_names))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    # file_names = tf.matching_files(files_name_pattern)
    dataset = data.TextLineDataset(filenames=file_names)
    # Skip the header line(s)
    dataset = dataset.skip(skip_header_lines)
    # Shuffle
    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)
    # Take one batch at a time
    dataset = dataset.batch(batch_size)
    # Parse the data
    dataset = dataset.map(lambda csv_row: parse_csv_row(csv_row),
                          num_parallel_calls=num_threads)
    # If further processing is enabled, add the new columns
    if PROCESS_FEATURES:
        dataset = dataset.map(lambda features, target:
                              (process_features(features), target),
                              num_parallel_calls=num_threads)
    # Restart the dataset once each epoch is exhausted
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    # Get the next (feature dict, target) values
    features, target = iterator.get_next()
    return features, target
Example #10
def csv_input_fn(files_name_pattern,
                 mode=tf.estimator.ModeKeys.EVAL,
                 skip_header_lines=0,
                 num_epochs=1,
                 batch_size=20):
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Input file(s): {}".format(files_name_pattern))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(files_name_pattern)

    dataset = data.TextLineDataset(filenames=file_names)
    dataset = dataset.skip(skip_header_lines)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda csv_row: parse_csv_row(csv_row),
                          num_parallel_calls=num_threads)

    # dataset = dataset.batch(batch_size) #??? very long time
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

    features, target = iterator.get_next()
    return features, target
def csv_input_fn(files_name_pattern,
                 mode=tf.estimator.ModeKeys.EVAL,
                 skip_header_lines=0,
                 num_epochs=1,
                 batch_size=20):

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TextLineDataset(filenames=file_names)
    dataset = dataset.skip(skip_header_lines)
    num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda csv_row: parse_csv_row(csv_row),
                          num_parallel_calls=num_threads)

    # dataset = dataset.batch(batch_size) #??? very long time
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

    features, target = iterator.get_next()
    return features, target
Example #12
def dataset_input_fn(file_names_pattern,
                     file_encoding='csv',
                     mode=tf.estimator.ModeKeys.EVAL,
                     skip_header_lines=0,
                     num_epochs=1,
                     batch_size=200,
                     multi_threading=True):
    """An input function for training or evaluation.
    This uses the Dataset APIs.

    Args:
        file_names_pattern: [str] - file name or file name patterns from which to read the data.
        mode: tf.estimator.ModeKeys - either TRAIN or EVAL.
            Used to determine whether or not to randomize the order of data.
        file_encoding: type of the text files. Can be 'csv' or 'tfrecords'
        skip_header_lines: int set to non-zero in order to skip header lines
          in CSV files.
        num_epochs: int - how many times through to read the data.
          If None will loop through data indefinitely
        batch_size: int - first dimension size of the Tensors returned by
          input_fn
        multi_threading: boolean - indicator to use multi-threading or not
    Returns:
        A (features, target) tuple, where features is a dictionary of
          Tensors and target is a single Tensor of labels.
    """

    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    data_size = parameters.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None

    num_threads = multiprocessing.cpu_count() if multi_threading else 1

    buffer_size = 2 * batch_size + 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Mode: {}".format(mode))
    print("Input file(s): {}".format(file_names_pattern))
    print("Files encoding: {}".format(file_encoding))
    print("Data size: {}".format(data_size))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(file_names_pattern)

    if file_encoding == 'csv':
        dataset = data.TextLineDataset(filenames=file_names)
        dataset = dataset.skip(skip_header_lines)
        dataset = dataset.map(lambda csv_row: parse_csv(csv_row))

    else:
        dataset = data.TFRecordDataset(filenames=file_names)
        dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example),
                              num_parallel_calls=num_threads)

    dataset = dataset.map(lambda features: get_features_target_tuple(features),
                          num_parallel_calls=num_threads)
    dataset = dataset.map(lambda features, target: (process_features(features), target),
                          num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()

    return features, target
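A sketch of wiring this input function into a tf.estimator training loop; the estimator, feature columns and file pattern below are assumptions for illustration, not part of the original code:

import tensorflow as tf

# Hypothetical schema and estimator, shown only to illustrate the contract:
# dataset_input_fn returns a (features, target) pair usable by an Estimator.
feature_columns = [tf.feature_column.numeric_column("x")]
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)

estimator.train(
    input_fn=lambda: dataset_input_fn(
        file_names_pattern="data/train-*.csv",  # placeholder pattern
        mode=tf.estimator.ModeKeys.TRAIN,
        num_epochs=None,   # loop indefinitely; `steps` bounds training instead
        batch_size=200),
    steps=1000)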
Example #13
def commonDataSet(data_type,
                  tensor=None,
                  record_bytes=None,
                  filename=None,
                  file_pattern=None,
                  generator=None,
                  output_types=None,
                  end_point=None,
                  step=1,
                  output_shapes=None,
                  args=None,
                  compression_type=None,
                  buffer_size=None,
                  num_parallel_reads=None,
                  header_bytes=None,
                  footer_bytes=None,
                  start_point=0,
                  datasets=None):
    assert data_type in {
        "file_line", "file_list", "TFrecord", "binary", "tensor_slices",
        "tensor", "generator", "range", "zip", "debug"
    }

    if data_type == "file_line":  # lines of a text file
        DataSet = data.TextLineDataset(filenames=filename,
                                       compression_type=compression_type,
                                       buffer_size=buffer_size,
                                       num_parallel_reads=num_parallel_reads)
        """
        filenames: a string tensor, or a list of file names
        compression_type: optional; "" (no compression), "ZLIB" or "GZIP"; default None
        buffer_size: optional int, number of bytes to buffer; 0 selects a default
            based on the compression type; default None
        num_parallel_reads: optional int, number of files to read in parallel;
            if greater than 1, records from the files read in parallel are output
            in interleaved order; if None, files are read sequentially
        """

    elif data_type == "file_list":
        DataSet = data.Dataset.list_files(file_pattern=file_pattern,
                                          shuffle=None,
                                          seed=None)
        """
        file_pattern  使用的文件列表 或字符串 "/path/*.txt"
        shuffle  是否进行打乱操作  默认为打乱
        seed  随机种子数  int
        """

    elif data_type == "TFrecord":  # 读取TFreacord 文件使用的  tfrecord 文件列表
        DataSet = data.TFRecordDataset(filenames=filename,
                                       compression_type=compression_type,
                                       buffer_size=buffer_size,
                                       num_parallel_reads=num_parallel_reads)
        """
        filenames   使用的字符串tensor 或者 包含多个文件名
        compression_type  可选项 ""(不压缩)   ZLIB  GZIP  字符串   None
        buffer_size 可选, 整形, 表示要缓存额大小, 0 会导致根据亚索类型选择默认的缓冲值  None
        num_parallel_reads  可选, int  表示要并行读取的文件数。 如果值大于1,  None
        则以交错顺序输出并行读取的文件记录, 如果设为None ,则被顺序读取
        """

    elif data_type == "binary":  # CIFAR10数据集使用的数据格式就是这种, 二进制文件
        DataSet = data.FixedLengthRecordDataset(
            filenames=filename,
            record_bytes=record_bytes,
            header_bytes=header_bytes,
            footer_bytes=footer_bytes,
            buffer_size=buffer_size,
            compression_type=compression_type,
            num_parallel_reads=num_parallel_reads)
        """
        filenames   使用的字符串tensor 或者 tf.data.Dataset中包含多个文件名
        record_bytes   tf.int64    数据类型
        header_bytes  表示文件开头要跳过的字节数, 可选
        footer_bytes  表示文件末尾要忽略的字节数
        buffer_size 可选, 整形, 表示要缓存额大小, 0 会导致根据亚索类型选择默认的缓冲值  None
        compression_type  可选项 ""(不压缩)   ZLIB  GZIP  字符串   None       
        num_parallel_reads  可选, int  表示要病毒读取的文件数。 如果值大于1,  None
        则以交错顺序输出并行读取的文件记录, 如果设为None ,则被顺序读取
        """

    elif data_type == "generator":
        DataSet = data.Dataset.from_generator(
            generator=generator,
            output_types=output_types,
            output_shapes=output_shapes,
            args=args,
        )
        """
        generator  , iter迭代对象, 若args未指定,怎, generator不必带参数,否则,它必须于args 中的参数一样多
        output_types 数据类型
        output_shapes 尺寸  可选
        args
        """

    elif data_type == "range":
        DataSet = data.Dataset.range(start_point, end_point, step)

    elif data_type == "zip":  # zip 操作, 对两个或者对个数据集进行合并操作
        DataSet = data.Dataset.zip(datasets=datasets)
        """
        dataset  必须是一个tuple   (datasetA, datasetB)
        """
    elif data_type == "tensor_slices":  # 张量 作用于 from_tensor 的区别, from_tensor_slices的作用是可以将tensor进行切分
        # 切分第一个维, 如 tensor 使用的是 np.random.uniform(size=(5,2)).    即 划分为 5行2列的数据形状。  5行代表5个样本
        # 2列为每个样本的维。
        DataSet = data.Dataset.from_tensor_slices(tensors=tensor)
        # 实际上 是将array作为一个tf.constants 保存到计算图中, 当array比较大的时候,导致计算图非常大, 给传输于保存带来不变。
        # 此时使用placeholder 取代这里的array.  并使用initializable iterator, 只在需要时,将array传进去,
        # 这样可以避免大数组保存在图中
    elif data_type == "debug":
        DataSet = data.Dataset(...)

    else:
        DataSet = data.Dataset.from_tensors(tensors=tensor)

    return DataSet
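A small sketch of the placeholder plus initializable-iterator pattern mentioned in the tensor_slices comments above, so a large array is fed at initialization time instead of being stored as a constant in the graph:

import numpy as np
import tensorflow as tf

features = np.random.uniform(size=(5, 2)).astype(np.float32)

# The placeholder keeps the array out of the serialized graph.
features_placeholder = tf.placeholder(tf.float32, shape=[None, 2])
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)

iterator = dataset.make_initializable_iterator()
next_row = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={features_placeholder: features})
    print(sess.run(next_row))  # first of the 5 rows, shape (2,)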
Example #14
    def _input_fn():

        shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

        data_size = task.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None

        num_threads = multiprocessing.cpu_count() if multi_threading else 1

        buffer_size = 2 * batch_size + 1

        print("")
        print("* data input_fn:")
        print("================")
        print("Mode: {}".format(mode))
        print("Input file(s): {}".format(file_names_pattern))
        print("Files encoding: {}".format(file_encoding))
        print("Data size: {}".format(data_size))
        print("Batch size: {}".format(batch_size))
        print("Epoch count: {}".format(num_epochs))
        print("Thread count: {}".format(num_threads))
        print("Shuffle: {}".format(shuffle))
        print("================")
        print("")

        file_names = tf.matching_files(file_names_pattern)

        if file_encoding == 'csv':
            # Process num_threads files in parallel. Since the batch function
            # is called before the map function, we use a block length of 1 to
            # output one batch from every input file before moving on to the
            # next file. (See the small interleave sketch at the end of this example.)
            dataset = (data.Dataset.from_tensor_slices(file_names).interleave(
                lambda x: data.TextLineDataset(x).skip(skip_header_lines).
                batch(batch_size).map(parse_csv,
                                      num_parallel_calls=num_threads),
                cycle_length=num_threads,
                block_length=1))

        else:
            dataset = data.TFRecordDataset(filenames=file_names)
            dataset = dataset.batch(batch_size)
            dataset = dataset.map(
                lambda tf_examples: parse_tf_example(tf_examples),
                num_parallel_calls=num_threads)

        dataset = dataset.map(
            lambda features: get_features_target_tuple(features),
            num_parallel_calls=num_threads)
        dataset = dataset.map(lambda features, target:
                              (process_features(features), target),
                              num_parallel_calls=num_threads)

        if shuffle:
            dataset = dataset.shuffle(buffer_size)

        dataset = dataset.prefetch(buffer_size)
        dataset = dataset.repeat(num_epochs)

        iterator = dataset.make_one_shot_iterator()
        features, target = iterator.get_next()

        return features, target
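Finally, a toy sketch of the interleave behaviour relied on above: with cycle_length datasets open at once and block_length elements taken from each in turn, one block is emitted from every input before moving on.

import tensorflow as tf

# Each of the 3 inputs becomes a tiny dataset repeated twice.
dataset = tf.data.Dataset.range(3).interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(2),
    cycle_length=2,
    block_length=1)

with tf.Session() as sess:
    next_value = dataset.make_one_shot_iterator().get_next()
    values = [sess.run(next_value) for _ in range(6)]
    # Expected order: 0, 1, 0, 1, 2, 2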