def disk_image_batch_dataset(img_paths, batch_size, labels=None,
                             prefetch_batch=_N_CPU + 1, drop_remainder=True,
                             filter=None, map_func=None, num_threads=_N_CPU,
                             shuffle=True, buffer_size=4096, repeat=-1):
    """Disk image batch dataset.

    Builds a `tf.data` pipeline that reads image files from disk, decodes
    them to 3-channel tensors, and hands the dataset to `batch_dataset`
    for shuffling/batching/prefetching.

    This function is suitable for jpg and png files (the decode op used
    below handles both — see note at the decode call).

    Args:
        img_paths: String list or 1-D tensor, each of which is an image path.
        batch_size: Batch size passed through to `batch_dataset`.
        labels: Label list/tuple_of_list or tensor/tuple_of_tensor, each of
            which is a corresponding label. Optional.
        prefetch_batch: Number of batches to prefetch.
        drop_remainder: Whether to drop the final partial batch.
        filter: Optional filter predicate forwarded to `batch_dataset`.
            NOTE(review): shadows the `filter` builtin, but renaming would
            break keyword callers, so the name is kept.
        map_func: Optional function applied to the decoded image (and
            labels, if any) before batching.
        num_threads: Parallelism for the map stage.
        shuffle: Whether to shuffle with `buffer_size`.
        buffer_size: Shuffle buffer size.
        repeat: Number of epochs; -1 presumably means repeat forever —
            confirm against `batch_dataset`.

    Returns:
        The batched `tf.data.Dataset` produced by `batch_dataset`.
    """
    # Slice paths (and labels, when given) into per-example elements.
    # A tuple of label collections becomes parallel tensor slices.
    if labels is None:
        dataset = tf.data.Dataset.from_tensor_slices(img_paths)
    elif isinstance(labels, tuple):
        dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels))
    else:
        dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))

    def parse_func(path, *label):
        # Read the raw bytes and decode; channels=3 forces RGB output.
        # TF's decode_png/decode_jpeg kernels accept either format, which
        # is why a single decode op covers both jpg and png inputs.
        img = tf.read_file(path)
        img = tf.image.decode_png(img, 3)
        return (img,) + label

    if map_func:
        # Compose the caller's map_func with the decode step so
        # batch_dataset only needs a single mapping function.
        def map_func_(*args):
            return map_func(*parse_func(*args))
    else:
        map_func_ = parse_func

    # dataset = dataset.map(parse_func, num_parallel_calls=num_threads) is slower
    dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder,
                            filter, map_func_, num_threads, shuffle, buffer_size,
                            repeat)

    return dataset
def tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size,
                           prefetch_batch=_N_CPU + 1, drop_remainder=True,
                           filter=None, map_func=None, num_threads=_N_CPU,
                           shuffle=True, buffer_size=4096, repeat=-1):
    """Tfrecord batch dataset.

    Reads serialized examples from tfrecord files, decodes and reshapes
    each named feature according to `infos`, then batches the result via
    `batch_dataset`.

    infos: for example
        [{'name': 'img', 'decoder': tf.image.decode_png,
          'decode_param': {}, 'shape': [112, 112, 1]},
         {'name': 'point', 'decoder': tf.decode_raw,
          'decode_param': dict(out_type = tf.float32), 'shape': [136]}]
    """
    dataset = tf.data.TFRecordDataset(
        tfrecord_files,
        compression_type=compression_type,
        buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    # Every feature arrives as a serialized string; decoding happens below.
    features = {info['name']: tf.FixedLenFeature([], tf.string)
                for info in infos}

    def parse_func(serialized_example):
        # Deserialize one example, then decode and reshape each feature
        # with its configured decoder and parameters.
        example = tf.parse_single_example(serialized_example, features=features)
        feature_dict = {}
        for info in infos:
            decoded = info['decoder'](example[info['name']],
                                      **info['decode_param'])
            feature_dict[info['name']] = tf.reshape(decoded, info['shape'])
        return feature_dict

    dataset = dataset.map(parse_func, num_parallel_calls=num_threads)

    dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder,
                            filter, map_func, num_threads, shuffle, buffer_size,
                            repeat)

    return dataset
def memory_data_batch_dataset(memory_data_dict, batch_size,
                              prefetch_batch=_N_CPU + 1, drop_remainder=True,
                              filter=None, map_func=None, num_threads=_N_CPU,
                              shuffle=True, buffer_size=4096, repeat=-1):
    """Memory data batch dataset.

    Wraps in-memory tensors/arrays as a `tf.data.Dataset` and batches them
    through `batch_dataset`.

    `memory_data_dict` example:
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}

    * The value of each item of `memory_data_dict` is in shape of (N, ...).
    """
    # Slice along the leading N axis, then delegate all batching options.
    source = tf.data.Dataset.from_tensor_slices(memory_data_dict)
    return batch_dataset(source, batch_size, prefetch_batch, drop_remainder,
                         filter, map_func, num_threads, shuffle, buffer_size,
                         repeat)