Example #1
def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             prefetch_batch=_N_CPU + 1,
                             drop_remainder=True,
                             filter=None,
                             map_func=None,
                             num_threads=_N_CPU,
                             shuffle=True,
                             buffer_size=4096,
                             repeat=-1):
    """Disk image batch dataset.

    This function is suitable for jpg and png files.

    img_paths: string list or 1-D tensor, each of which is an image path
    labels: label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label
    """
    if labels is None:
        dataset = tf.data.Dataset.from_tensor_slices(img_paths)
    elif isinstance(labels, tuple):
        dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels))
    else:
        dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))

    def parse_func(path, *label):
        img = tf.read_file(path)
        # decode the image, forcing 3 channels (RGB); per the docstring,
        # jpg files are handled here as well
        img = tf.image.decode_png(img, 3)
        return (img,) + label

    if map_func:
        # chain the user-supplied map_func after the decoding step
        def map_func_(*args):
            return map_func(*parse_func(*args))
    else:
        map_func_ = parse_func

    # Note: a separate dataset.map(parse_func, num_parallel_calls=num_threads)
    # call here is slower than fusing parse_func into batch_dataset below

    dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter,
                            map_func_, num_threads, shuffle, buffer_size, repeat)

    return dataset
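
The examples here rely on batch_dataset, _N_CPU, and _DEFAULT_READER_BUFFER_SIZE_BYTES, none of which appear in this excerpt. Below is a minimal sketch of what they might look like, assuming a plain filter -> map -> shuffle -> batch -> repeat -> prefetch pipeline and TF >= 1.10 (for drop_remainder in Dataset.batch); the real helper may differ:

import multiprocessing

import tensorflow as tf

_N_CPU = multiprocessing.cpu_count()
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # assumed value


def batch_dataset(dataset, batch_size, prefetch_batch=_N_CPU + 1,
                  drop_remainder=True, filter=None, map_func=None,
                  num_threads=_N_CPU, shuffle=True, buffer_size=4096,
                  repeat=-1):
    # sketch only: the order of the stages is an assumption
    if filter:
        dataset = dataset.filter(filter)
    if map_func:
        dataset = dataset.map(map_func, num_parallel_calls=num_threads)
    if shuffle:
        dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    # repeat(-1) repeats indefinitely; prefetch overlaps input and compute
    dataset = dataset.repeat(repeat).prefetch(prefetch_batch)
    return dataset

With those in place, a call might look like this; the file names, crop size, and normalization are illustrative only:

paths = ['data/0001.jpg', 'data/0002.jpg', 'data/0003.jpg']  # hypothetical files
labels = [0, 1, 0]

def _augment(img, label):
    img = tf.random_crop(img, [96, 96, 3])        # assumed crop size
    img = tf.cast(img, tf.float32) / 127.5 - 1.0  # scale to [-1, 1]
    return img, label

dataset = disk_image_batch_dataset(paths, batch_size=2,
                                   labels=labels, map_func=_augment)
img_batch, label_batch = dataset.make_one_shot_iterator().get_next()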
Example #2
def tfrecord_batch_dataset(tfrecord_files,
                           infos,
                           compression_type,
                           batch_size,
                           prefetch_batch=_N_CPU + 1,
                           drop_remainder=True,
                           filter=None,
                           map_func=None,
                           num_threads=_N_CPU,
                           shuffle=True,
                           buffer_size=4096,
                           repeat=-1):
    """Tfrecord batch dataset.

    infos:
        for example
        [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]},
         {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}]
    """
    dataset = tf.data.TFRecordDataset(
        tfrecord_files,
        compression_type=compression_type,
        buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    # every field is stored as a scalar serialized bytes feature
    features = {}
    for info in infos:
        features[info['name']] = tf.FixedLenFeature([], tf.string)

    def parse_func(serialized_example):
        example = tf.parse_single_example(serialized_example,
                                          features=features)

        feature_dict = {}
        for info in infos:
            name = info['name']
            decoder = info['decoder']
            decode_param = info['decode_param']
            shape = info['shape']

            feature = decoder(example[name], **decode_param)
            feature = tf.reshape(feature, shape)  # restore the static shape
            feature_dict[name] = feature

        return feature_dict

    dataset = dataset.map(parse_func, num_parallel_calls=num_threads)

    dataset = batch_dataset(dataset, batch_size, prefetch_batch,
                            drop_remainder, filter, map_func, num_threads,
                            shuffle, buffer_size, repeat)

    return dataset
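
Assuming each record stores every field as a serialized bytes feature (an encoded PNG for 'img', raw float32 bytes for 'point'), usage might look like this; the file name and field layout are illustrative only and must match how the TFRecords were written:

infos = [
    {'name': 'img', 'decoder': tf.image.decode_png,
     'decode_param': {}, 'shape': [112, 112, 1]},
    {'name': 'point', 'decoder': tf.decode_raw,
     'decode_param': dict(out_type=tf.float32), 'shape': [136]},
]
dataset = tfrecord_batch_dataset(['train.tfrecord'], infos,  # hypothetical file
                                 compression_type='', batch_size=64)
batch = dataset.make_one_shot_iterator().get_next()
# batch is a dict of dense tensors:
# batch['img']: [64, 112, 112, 1] uint8, batch['point']: [64, 136] float32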
Example #3
def memory_data_batch_dataset(memory_data_dict,
                              batch_size,
                              prefetch_batch=_N_CPU + 1,
                              drop_remainder=True,
                              filter=None,
                              map_func=None,
                              num_threads=_N_CPU,
                              shuffle=True,
                              buffer_size=4096,
                              repeat=-1):
    """Memory data batch dataset.

    `memory_data_dict` example:
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}
        * Each value of `memory_data_dict` has shape (N, ...).
    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict)
    dataset = batch_dataset(dataset, batch_size, prefetch_batch,
                            drop_remainder, filter, map_func, num_threads,
                            shuffle, buffer_size, repeat)
    return dataset
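
Note that from_tensor_slices embeds ndarray inputs as constants in the graph, so very large arrays are often fed through placeholders instead; for modest arrays a direct call works. A usage sketch with made-up shapes:

import numpy as np

data = {'img': np.zeros([1000, 64, 64, 3], np.float32),  # hypothetical data
        'label': np.zeros([1000], np.int64)}
dataset = memory_data_batch_dataset(data, batch_size=32)
batch = dataset.make_one_shot_iterator().get_next()
# batch['img']: [32, 64, 64, 3], batch['label']: [32]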