def _construct_dataset(record_path, batch_size, sess):
    def parse_record(serialized_example):
        # parse a single record
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_l': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32),
                'image_ab': tf.FixedLenFeature([IMAGE_SIZE, IMAGE_SIZE, 2], tf.float32),
                'image_features': tf.FixedLenFeature([1000, ], tf.float32)
            })
        l, ab, embed = features['image_l'], features['image_ab'], features['image_features']
        return l, ab, embed

    dataset = tfdata.TFRecordDataset([record_path], 'ZLIB')  # create a Dataset to wrap the TFRecord
    dataset = dataset.map(parse_record, num_parallel_calls=2)  # parse the record
    dataset = dataset.repeat()  # repeat forever
    dataset = dataset.batch(batch_size)  # batch into the required batch size
    dataset = dataset.shuffle(buffer_size=5)  # shuffle the batches
    iterator = dataset.make_initializable_iterator()  # get an iterator over the dataset
    sess.run(iterator.initializer)  # initialize the iterator
    next_batch = iterator.get_next()  # get the iterator Tensor
    return dataset, next_batch
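# Usage sketch (added for illustration, not from the original source): assuming TensorFlow 1.x
# graph-mode semantics as in _construct_dataset above, with IMAGE_SIZE defined at module level
# and a placeholder record path; it shows how the returned next_batch tensors would be evaluated.
def _demo_construct_dataset(record_path='train.tfrecord'):  # hypothetical helper and path
    with tf.Session() as sess:
        dataset, next_batch = _construct_dataset(record_path, batch_size=32, sess=sess)
        # each run of next_batch yields one batch of (L channel, ab channels, embeddings)
        l_batch, ab_batch, embed_batch = sess.run(next_batch)
        return l_batch.shape, ab_batch.shape, embed_batch.shape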
def _input_fn(data_dir=TFRECORD_DIR, batch_size=BATCH_SIZE):
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    num_threads = multiprocessing.cpu_count() if multi_threading else 1
    buffer_size = 2 * batch_size + 1

    file_names = tf.matching_files(data_dir)

    feature_spec = {
        'id': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.float32),
        'feat': tf.FixedLenFeature([FEAT_LEN], tf.float32),
    }

    dataset = data.TFRecordDataset(filenames=file_names, compression_type='GZIP')
    # dataset = dataset.map(lambda tf_example: printout(tf_example))
    dataset = dataset.map(
        lambda tf_example: tf.parse_example(serialized=[tf_example], features=feature_spec),
        num_parallel_calls=num_threads)
    dataset = dataset.map(lambda features: get_features_target_tuple(features),
                          num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target
def read(data_dir: str,
         feature_config: FeatureConfig,
         tfrecord_type: str,
         file_io: FileIO,
         max_sequence_size: int = 0,
         batch_size: int = 0,
         preprocessing_keys_to_fns: dict = {},
         parse_tfrecord: bool = True,
         use_part_files: bool = False,
         logger: Logger = None,
         **kwargs) -> data.TFRecordDataset:
    """
    - reads tfrecord data from an input directory
    - selects relevant features
    - creates X and y data

    Args:
        data_dir: path to directory containing the TFRecord files to read
        feature_config: ml4ir.config.features.Features object extracted from the feature config
        tfrecord_type: either "example" or "sequence_example"
        max_sequence_size: maximum number of sequence records per SequenceExample
        batch_size: int value specifying the size of the batch
        preprocessing_keys_to_fns: dictionary mapping preprocessing keys in the feature_config to functions
        parse_tfrecord: whether to parse SequenceExamples into features
        use_part_files: whether to load the dataset from "part-" prefixed files
        logger: logging object

    Returns:
        tensorflow dataset
    """
    parse_fn = get_parse_fn(
        feature_config=feature_config,
        tfrecord_type=tfrecord_type,
        preprocessing_keys_to_fns=preprocessing_keys_to_fns,
        max_sequence_size=max_sequence_size,
    )

    # Get all tfrecord files in directory
    tfrecord_files = file_io.get_files_in_directory(
        data_dir,
        extension="" if use_part_files else ".tfrecord",
        prefix="part-" if use_part_files else "",
    )

    # Parse the protobuf data to create a TFRecordDataset
    dataset = data.TFRecordDataset(tfrecord_files)
    if parse_tfrecord:
        dataset = dataset.map(parse_fn).apply(data.experimental.ignore_errors())

    # Create BatchedDataSet
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder=True)

    if logger:
        logger.info(
            "Created TFRecordDataset from SequenceExample protobufs from {} files : {}".format(
                len(tfrecord_files), str(tfrecord_files)[:50]))

    return dataset
def tfrecords_input_fn(files_name_pattern,
                       mode=tf.estimator.ModeKeys.EVAL,
                       max_history_len=None,
                       num_epochs=1,
                       batch_size=200):
    """
    Applies padded_batch:
    - a shape is specified for each feature in the dict
    - padding is filled with the default value of 0.0

    :param files_name_pattern:
    :param mode:
    :param max_history_len:
    :param num_epochs:
    :param batch_size:
    :return:
    """
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1
    buffer_size = 2 * batch_size + 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Input file(s): {}".format(files_name_pattern))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    if max_history_len is not None:
        print("max_history_len: {}".format(max_history_len))
    print("================")
    print("")

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TFRecordDataset(filenames=file_names)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.map(
        lambda tf_example: parse_tf_example(tf_example, max_history_len))
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=({
                                       'uid': [],
                                       'w2vecs': [None, DIM],
                                       'dids': [None],
                                       'sl': []
                                   }, []))

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.prefetch(buffer_size)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target
def _input_fn():
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    data_size = HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None
    num_threads = multiprocessing.cpu_count() if multi_threading else 1
    buffer_size = 2 * batch_size + 1

    tf.compat.v1.logging.info("")
    tf.compat.v1.logging.info("* datasets input_fn:")
    tf.compat.v1.logging.info("================")
    tf.compat.v1.logging.info("Mode: {}".format(mode))
    tf.compat.v1.logging.info("Input file(s): {}".format(file_names_pattern))
    tf.compat.v1.logging.info("Files encoding: {}".format(file_encoding))
    tf.compat.v1.logging.info("Data size: {}".format(data_size))
    tf.compat.v1.logging.info("Batch size: {}".format(batch_size))
    tf.compat.v1.logging.info("Epoch count: {}".format(num_epochs))
    tf.compat.v1.logging.info("Thread count: {}".format(num_threads))
    tf.compat.v1.logging.info("Shuffle: {}".format(shuffle))
    tf.compat.v1.logging.info("================")
    tf.compat.v1.logging.info("")

    file_names = tf.io.matching_files(file_names_pattern)

    if file_encoding == 'csv':
        dataset = data.TextLineDataset(filenames=file_names)
        dataset = dataset.skip(skip_header_lines)
        dataset = dataset.map(lambda csv_row: parse_csv(csv_row))
    else:
        dataset = data.TFRecordDataset(filenames=file_names)
        dataset = dataset.map(
            lambda tf_example: parse_tf_example(tf_example, HYPER_PARAMS, mode=mode),
            num_parallel_calls=num_threads)

    dataset = dataset.map(
        lambda features: get_features_target_tuple(features, mode=mode),
        num_parallel_calls=num_threads)
    dataset = dataset.map(
        lambda features, target: (process_features(features, HYPER_PARAMS=HYPER_PARAMS, mode=mode), target),
        num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()

    features, target = posprocessing(features, target, mode, HYPER_PARAMS)

    return features, target
def _input_fn():
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    data_size = task.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None
    num_threads = multiprocessing.cpu_count() if multi_threading else 1
    buffer_size = 2 * batch_size + 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Mode: {}".format(mode))
    print("Input file(s): {}".format(file_names_pattern))
    print("Files encoding: {}".format(file_encoding))
    print("Data size: {}".format(data_size))
    print("Batch size: {}".format(batch_size))
    print("Epoch count: {}".format(num_epochs))
    print("Thread count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(file_names_pattern)

    if file_encoding == 'csv':
        dataset = data.TextLineDataset(filenames=file_names)
        dataset = dataset.skip(skip_header_lines)
        dataset = dataset.map(lambda csv_row: parse_csv(csv_row))  # just append column_names to data
    else:
        dataset = data.TFRecordDataset(filenames=file_names)
        dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example),
                              num_parallel_calls=num_threads)

    dataset = dataset.map(
        lambda features: get_features_target_tuple(features),  # features here is a dictionary
        num_parallel_calls=num_threads)
    dataset = dataset.map(lambda features, target: (process_features(features), target),
                          num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target  # features: dictionary, target: value
def from_tfrecords(self, files):
    dataset = data.TFRecordDataset(files)
    dataset = dataset.map(map_func=self._preprocess_example,
                          num_parallel_calls=self.num_threads)
    dataset = dataset.repeat(self.repeat)
    if self.shuffle:
        dataset = dataset.shuffle(buffer_size=self.shuffle_buffer, seed=self.seed)
    dataset = dataset.batch(self.batch_size)
    return dataset
def from_tfrecords(self, files):
    dataset = data.TFRecordDataset(files)
    dataset = dataset.map(map_func=self._parse_function,
                          num_parallel_calls=self.num_threads)
    dataset = dataset.repeat(self.repeat)
    if self.shuffle:
        dataset = dataset.shuffle(buffer_size=self.shuffle_buffer,
                                  seed=self.seed,
                                  reshuffle_each_iteration=True)
    dataset = dataset.batch(self.batch_size)
    return dataset
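# Usage sketch (added for illustration, not from the original source): assuming `reader` is an
# instance of the class that defines from_tfrecords above, with num_threads, repeat, shuffle,
# shuffle_buffer, seed and batch_size already configured, the returned dataset can be iterated
# directly under TF 2.x eager execution.
def _demo_from_tfrecords(reader, files):  # `reader` and `files` are hypothetical inputs
    dataset = reader.from_tfrecords(files)
    for batch in dataset.take(2):  # inspect the first two parsed, batched elements
        print(batch)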
def tfrecords_input_fn(files_name_pattern,
                       feature_spec,
                       label,
                       mode=tf.estimator.ModeKeys.EVAL,
                       num_epochs=None,
                       batch_size=64):
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TFRecordDataset(filenames=file_names)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example, label, feature_spec))
    dataset = dataset.repeat(num_epochs)
    return dataset
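# Usage sketch (added for illustration, not from the original source): because the input_fn above
# returns a tf.data.Dataset, it can be handed to a tf.estimator.Estimator through a closure.
# The estimator, feature_spec, label key and file pattern below are placeholders.
def _demo_train(estimator, feature_spec, label_key):  # hypothetical, pre-built arguments
    train_input_fn = lambda: tfrecords_input_fn(
        "data/train-*.tfrecord", feature_spec, label_key,
        mode=tf.estimator.ModeKeys.TRAIN, num_epochs=10, batch_size=64)
    estimator.train(input_fn=train_input_fn)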
def read(data_dir: str,
         feature_config: FeatureConfig,
         max_num_records: int = 25,
         batch_size: int = 128,
         parse_tfrecord: bool = True,
         use_part_files: bool = False,
         logger: Logger = None,
         **kwargs) -> data.TFRecordDataset:
    """
    - reads tfrecord data from an input directory
    - selects relevant features
    - creates X and y data

    Args:
        data_dir: path to directory containing the TFRecord files to read
        feature_config: ml4ir.config.features.Features object extracted from the feature config
        max_num_records: maximum number of records per SequenceExample
        batch_size: int value specifying the size of the batch
        parse_tfrecord: whether to parse SequenceExamples into features
        logger: logging object

    Returns:
        tensorflow dataset
    """
    # Generate parsing function
    parse_sequence_example_fn = make_parse_fn(feature_config=feature_config,
                                              max_num_records=max_num_records)

    # Get all tfrecord files in directory
    tfrecord_files = file_io.get_files_in_directory(
        data_dir,
        extension="" if use_part_files else ".tfrecord",
        prefix="part-" if use_part_files else "",
    )

    # Parse the protobuf data to create a TFRecordDataset
    dataset = data.TFRecordDataset(tfrecord_files)
    if parse_tfrecord:
        dataset = dataset.map(parse_sequence_example_fn).apply(
            data.experimental.ignore_errors())

    dataset = dataset.batch(batch_size, drop_remainder=True)

    if logger:
        logger.info(
            "Created TFRecordDataset from SequenceExample protobufs from {} files : {}".format(
                len(tfrecord_files), str(tfrecord_files)[:50]))

    return dataset
def _input_fn():
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    file_names = data.Dataset.list_files(files_name_pattern)
    dataset = data.TFRecordDataset(filenames=file_names)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example))
    dataset = dataset.map(lambda features, target: (process_features(features), target))
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target
def get_batch_data(self):
    '''
    Fetch a batch of data.
    :return: image Tensor, label Tensor
    '''
    print("Dataset: ", self.data)
    dataSet = data.TFRecordDataset(self.data)
    dataSet = dataSet.map(self.parse)
    dataSet = dataSet.map(
        lambda image, label: (self.total_image_norm(image, [self.image_w, self.image_h, 3]), label))
    dataSet = dataSet.repeat()
    if self.shuffle:
        dataSet = dataSet.shuffle(1000)
    dataSet = dataSet.batch(self.batch_size)
    iterator = dataSet.make_initializable_iterator()
    image_batch, label_batch = iterator.get_next()
    self.sess.run(tf.local_variables_initializer())
    self.sess.run(iterator.initializer)
    label_batch = tf.reshape(label_batch, [-1, 1])
    return image_batch, label_batch
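# Usage sketch (added for illustration, not from the original source): assuming `loader` is an
# instance of the class that defines get_batch_data above, with sess, data, parse,
# total_image_norm, image_w, image_h, shuffle and batch_size already set up.
def _demo_get_batch_data(loader):  # `loader` is a hypothetical instance
    image_batch, label_batch = loader.get_batch_data()
    images, labels = loader.sess.run([image_batch, label_batch])
    print(images.shape, labels.shape)  # labels are reshaped to (batch_size, 1)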
def tfrecods_input_fn(files_name_pattern,
                      mode=tf.estimator.ModeKeys.EVAL,
                      num_epochs=None,
                      batch_size=200):
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    print("")
    print("* data input_fn:")
    print("================")
    print("Input file(s): {}".format(files_name_pattern))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TFRecordDataset(filenames=file_names)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example))

    if PROCESS_FEATURES:
        dataset = dataset.map(lambda features, target: (process_features(features), target))

    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target
def dataset_input_fn(file_names_pattern,
                     file_encoding='csv',
                     mode=tf.estimator.ModeKeys.EVAL,
                     skip_header_lines=0,
                     num_epochs=1,
                     batch_size=200,
                     multi_threading=True):
    """An input function for training or evaluation.
    This uses the Dataset APIs.

    Args:
        file_names_pattern: [str] - file name or file name patterns from which to read the data.
        mode: tf.estimator.ModeKeys - either TRAIN or EVAL.
            Used to determine whether or not to randomize the order of data.
        file_encoding: type of the text files. Can be 'csv' or 'tfrecords'
        skip_header_lines: int set to non-zero in order to skip header lines in CSV files.
        num_epochs: int - how many times through to read the data.
            If None will loop through data indefinitely
        batch_size: int - first dimension size of the Tensors returned by input_fn
        multi_threading: boolean - indicator to use multi-threading or not

    Returns:
        (features, indices) where features is a dictionary of Tensors,
        and indices is a single Tensor of label indices.
    """
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    data_size = parameters.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None

    num_threads = multiprocessing.cpu_count() if multi_threading else 1
    buffer_size = 2 * batch_size + 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Mode: {}".format(mode))
    print("Input file(s): {}".format(file_names_pattern))
    print("Files encoding: {}".format(file_encoding))
    print("Data size: {}".format(data_size))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(file_names_pattern)

    if file_encoding == 'csv':
        dataset = data.TextLineDataset(filenames=file_names)
        dataset = dataset.skip(skip_header_lines)
        dataset = dataset.map(lambda csv_row: parse_csv(csv_row))
    else:
        dataset = data.TFRecordDataset(filenames=file_names)
        dataset = dataset.map(lambda tf_example: parse_tf_example(tf_example),
                              num_parallel_calls=num_threads)

    dataset = dataset.map(lambda features: get_features_target_tuple(features),
                          num_parallel_calls=num_threads)
    dataset = dataset.map(lambda features, target: (process_features(features), target),
                          num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()

    return features, target
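# Usage sketch (added for illustration, not from the original source): the file patterns and the
# pre-built estimator below are placeholders; it shows dataset_input_fn wrapped in closures so the
# same function serves both training (shuffled, several epochs) and evaluation.
def _demo_train_and_evaluate(estimator):  # `estimator` is a hypothetical tf.estimator.Estimator
    train_input = lambda: dataset_input_fn(
        "data/train-*.tfrecords", file_encoding='tfrecords',
        mode=tf.estimator.ModeKeys.TRAIN, num_epochs=5, batch_size=200)
    eval_input = lambda: dataset_input_fn(
        "data/eval-*.tfrecords", file_encoding='tfrecords',
        mode=tf.estimator.ModeKeys.EVAL, num_epochs=1, batch_size=200)
    estimator.train(input_fn=train_input)
    estimator.evaluate(input_fn=eval_input)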
def _input_fn():
    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False
    data_size = task.HYPER_PARAMS.train_size if mode == tf.estimator.ModeKeys.TRAIN else None
    num_threads = multiprocessing.cpu_count() if multi_threading else 1
    buffer_size = 2 * batch_size + 1

    print("")
    print("* data input_fn:")
    print("================")
    print("Mode: {}".format(mode))
    print("Input file(s): {}".format(file_names_pattern))
    print("Files encoding: {}".format(file_encoding))
    print("Data size: {}".format(data_size))
    print("Batch size: {}".format(batch_size))
    print("Epoch count: {}".format(num_epochs))
    print("Thread count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(file_names_pattern)

    if file_encoding == 'csv':
        # Processes num_threads files in parallel at a time.
        # Also, since the batch function is called before the map function,
        # we use a block length of 1 to output one batch from every input
        # file before moving to the next file
        dataset = (data.Dataset.from_tensor_slices(file_names).interleave(
            lambda x: data.TextLineDataset(x).skip(skip_header_lines)
            .batch(batch_size).map(parse_csv, num_parallel_calls=num_threads),
            cycle_length=num_threads,
            block_length=1))
    else:
        dataset = data.TFRecordDataset(filenames=file_names)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(
            lambda tf_examples: parse_tf_example(tf_examples),
            num_parallel_calls=num_threads)

    dataset = dataset.map(
        lambda features: get_features_target_tuple(features),
        num_parallel_calls=num_threads)
    dataset = dataset.map(lambda features, target: (process_features(features), target),
                          num_parallel_calls=num_threads)

    if shuffle:
        dataset = dataset.shuffle(buffer_size)

    dataset = dataset.prefetch(buffer_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()
    features, target = iterator.get_next()
    return features, target
def read(
    data_dir: str,
    feature_config: FeatureConfig,
    tfrecord_type: str,
    file_io: FileIO,
    max_sequence_size: int = 0,
    batch_size: int = 0,
    preprocessing_keys_to_fns: dict = {},
    parse_tfrecord: bool = True,
    use_part_files: bool = False,
    logger: Logger = None,
    **kwargs
) -> data.TFRecordDataset:
    """
    Extract features by reading and parsing TFRecord data
    and converting into a TFRecordDataset using the FeatureConfig

    Parameters
    ----------
    data_dir : str
        path to the directory containing train, validation and test data
    feature_config : `FeatureConfig` object
        FeatureConfig object that defines the features to be loaded in the dataset
        and the preprocessing functions to be applied to each of them
    tfrecord_type : {"example", "sequence_example"}
        Type of the TFRecord protobuf message to be used for TFRecordDataset
    file_io : `FileIO` object
        file I/O handler objects for reading and writing data
    max_sequence_size : int, optional
        maximum number of sequence to be used with a single SequenceExample proto message
        The data will be appropriately padded or clipped to fit the max value specified
    batch_size : int, optional
        size of each data batch
    preprocessing_keys_to_fns : dict of (str, function), optional
        dictionary of function names mapped to function definitions
        that can now be used for preprocessing while loading the
        TFRecordDataset to create the RelevanceDataset object
    use_part_files : bool, optional
        load dataset from part files checked using "part-" prefix
    parse_tfrecord : bool, optional
        parse the TFRecord string from the dataset;
        returns strings as is otherwise
    logger : `Logger`, optional
        logging handler for status messages

    Returns
    -------
    `TFRecordDataset`
        TFRecordDataset loaded from the `data_dir` specified using the FeatureConfig
    """
    parse_fn = get_parse_fn(
        feature_config=feature_config,
        tfrecord_type=tfrecord_type,
        preprocessing_keys_to_fns=preprocessing_keys_to_fns,
        max_sequence_size=max_sequence_size,
    )

    # Get all tfrecord files in directory
    tfrecord_files = file_io.get_files_in_directory(
        data_dir,
        extension="" if use_part_files else ".tfrecord",
        prefix="part-" if use_part_files else "",
    )

    # Parse the protobuf data to create a TFRecordDataset
    dataset = data.TFRecordDataset(tfrecord_files)
    if parse_tfrecord:
        # Parallel calls set to AUTOTUNE: improved training performance by 40% with a classification model
        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).apply(
            data.experimental.ignore_errors()
        )

    # Create BatchedDataSet
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder=True)

    if logger:
        logger.info(
            "Created TFRecordDataset from SequenceExample protobufs from {} files : {}".format(
                len(tfrecord_files), str(tfrecord_files)[:50]
            )
        )

    # We apply prefetch as it improved train/test/validation throughput by 30% in some real model training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
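# Usage sketch (added for illustration, not from the original source): assuming `feature_config`
# and `file_io` objects have already been constructed elsewhere (ml4ir-specific setup not shown),
# the returned TFRecordDataset can be inspected like any other tf.data.Dataset.
def _demo_read(feature_config, file_io):  # both arguments are assumed to exist already
    dataset = read(
        data_dir="data/tfrecord",            # placeholder directory
        feature_config=feature_config,
        tfrecord_type="sequence_example",
        file_io=file_io,
        max_sequence_size=25,
        batch_size=128,
    )
    for batch in dataset.take(1):
        print(batch)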
def commonDataSet(data_type, tensor=None, record_bytes=None, filename=None,
                  file_pattern=None, generator=None, output_types=None,
                  end_point=None, step=1, output_shapes=None, args=None,
                  compression_type=None, buffer_size=None, num_parallel_reads=None,
                  header_bytes=None, footer_bytes=None, start_point=0, datasets=None):
    assert data_type in {
        "file_line", "file_list", "TFrecord", "binary", "tensor_slices", "tensor",
        "generator", "range", "zip", "debug"
    }
    if data_type == "file_line":  # lines of a text file
        DataSet = data.TextLineDataset(filenames=filename,
                                       compression_type=compression_type,
                                       buffer_size=buffer_size,
                                       num_parallel_reads=num_parallel_reads)
        """
        filenames            a string tensor or a list of file names
        compression_type     optional; one of "" (no compression), ZLIB, GZIP
        buffer_size          optional int; size of the read buffer. 0 picks a default
                             buffer size based on the compression type
        num_parallel_reads   optional int; number of files to read in parallel. If greater
                             than 1, records from the parallel reads are interleaved;
                             if None, files are read sequentially
        """
    elif data_type == "file_list":
        DataSet = data.Dataset.list_files(file_pattern=file_pattern, shuffle=None, seed=None)
        """
        file_pattern    a list of files or a glob string such as "/path/*.txt"
        shuffle         whether to shuffle the file names (defaults to shuffled)
        seed            int random seed
        """
    elif data_type == "TFrecord":  # read TFRecord files from a list of tfrecord file names
        DataSet = data.TFRecordDataset(filenames=filename,
                                       compression_type=compression_type,
                                       buffer_size=buffer_size,
                                       num_parallel_reads=num_parallel_reads)
        """
        filenames            a string tensor or a list of file names
        compression_type     optional; one of "" (no compression), ZLIB, GZIP
        buffer_size          optional int; size of the read buffer. 0 picks a default
                             buffer size based on the compression type
        num_parallel_reads   optional int; number of files to read in parallel. If greater
                             than 1, records from the parallel reads are interleaved;
                             if None, files are read sequentially
        """
    elif data_type == "binary":  # fixed-length binary records, the format used by the CIFAR-10 dataset
        DataSet = data.FixedLengthRecordDataset(
            filenames=filename,
            record_bytes=record_bytes,
            header_bytes=header_bytes,
            footer_bytes=footer_bytes,
            buffer_size=buffer_size,
            compression_type=compression_type,
            num_parallel_reads=num_parallel_reads)
        """
        filenames            a string tensor or a tf.data.Dataset containing file names
        record_bytes         tf.int64; number of bytes per record
        header_bytes         optional; bytes to skip at the start of each file
        footer_bytes         optional; bytes to ignore at the end of each file
        buffer_size          optional int; size of the read buffer. 0 picks a default
                             buffer size based on the compression type
        compression_type     optional; one of "" (no compression), ZLIB, GZIP
        num_parallel_reads   optional int; number of files to read in parallel. If greater
                             than 1, records from the parallel reads are interleaved;
                             if None, files are read sequentially
        """
    elif data_type == "generator":
        DataSet = data.Dataset.from_generator(
            generator=generator,
            output_types=output_types,
            output_shapes=output_shapes,
            args=args,
        )
        """
        generator        an iterable/generator object; if args is not given it must take no
                         arguments, otherwise it must take as many arguments as args provides
        output_types     element data types
        output_shapes    element shapes (optional)
        args             optional arguments passed to the generator
        """
    elif data_type == "range":
        DataSet = data.Dataset.range(start_point, end_point, step)
    elif data_type == "zip":  # zip: combine two or more datasets
        DataSet = data.Dataset.zip(datasets=datasets)
        """
        datasets must be a tuple, e.g. (datasetA, datasetB)
        """
    elif data_type == "tensor_slices":
        # Difference from from_tensors: from_tensor_slices slices the tensor along its first
        # dimension, e.g. np.random.uniform(size=(5, 2)) becomes 5 elements (samples) of
        # 2 features each.
        DataSet = data.Dataset.from_tensor_slices(tensors=tensor)
        # In practice the array is stored in the graph as a tf.constant; a large array makes the
        # graph very large and awkward to transfer and save. In that case replace the array with
        # a placeholder and use an initializable iterator, feeding the array only when needed,
        # which keeps the large array out of the graph.
    elif data_type == "debug":
        DataSet = data.Dataset(...)
    else:
        DataSet = data.Dataset.from_tensors(tensors=tensor)
    return DataSet
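# Usage sketch (added for illustration, not from the original source): the file name below is a
# placeholder; it shows commonDataSet building a TFRecord-backed dataset that is then batched and
# iterated with a one-shot iterator (TF 1.x graph mode, matching the snippets above).
def _demo_commonDataSet(sess):  # `sess` is an existing tf.Session
    ds = commonDataSet("TFrecord", filename=["train.tfrecord"], compression_type="")
    ds = ds.batch(32)
    iterator = ds.make_one_shot_iterator()
    next_record = iterator.get_next()
    print(sess.run(next_record))  # one batch of serialized tf.Example strings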