示例#1
0
def read_and_decode(filename, one_hot=True, n_class=None, is_train=None):
    """Build graph tensors that yield one (image, label) pair from a TFRecord.

    Args:
        filename: Path of the TFRecord file; it is wrapped in a one-element
            list for the filename queue.
        one_hot: When truthy (and `n_class` is truthy too) the label is
            one-hot encoded.
        n_class: One-hot depth. With `None` or 0 the label stays a scalar.
        is_train: Accepted but not used by this function.

    Returns:
        A tuple `(img, label)`: `img` is a float32 tensor of shape
        [28, 28, 1] with values in [-0.5, 0.5]; `label` is either an int32
        scalar or its one-hot encoding of depth `n_class`.
    """
    # Schema of a serialized example: a scalar int64 label and the raw bytes.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
    }

    queue = tf.train.string_input_producer([filename])
    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(queue)
    parsed = tf.parse_single_example(serialized_example,
                                     features=feature_spec)

    # Training-time image distortions could be inserted after decoding.
    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    # decode_raw cannot know the length statically; pin it to 28 * 28 bytes.
    pixels.set_shape([28 * 28])
    pixels = tf.reshape(pixels, [28, 28, 1])

    # Scale uint8 [0, 255] to [0, 1], then shift so values centre on zero.
    scale = 1. / 255
    img = tf.cast(pixels, tf.float32) * scale - 0.5

    label = tf.cast(parsed['label'], tf.int32)
    if not (one_hot and n_class):
        return img, label

    return img, tf.one_hot(label, n_class)
示例#2
0
def read_and_decode(filename,
                    w,
                    h,
                    one_hot=True,
                    n_class=None,
                    is_train=None,
                    bResize=False,
                    origImgW=0,
                    origImgH=0):
    """Return tensors that read one (image, label) example from TFRecords.

    Args:
        filename: A single TFRecord path, or a list (or 1-D string tensor)
            of paths. A single `str`/`bytes` path is wrapped in a list
            before it is handed to the filename queue.
        w: First element of the target size given to
            `tf.image.resize_images`; only used when `bResize` is true.
        h: Second element of that target size; only used when `bResize`
            is true.
        one_hot: When truthy (and `n_class` is truthy too) the label is
            one-hot encoded.
        n_class: One-hot depth. With `None` or 0 the label stays a scalar.
        is_train: Accepted but not used by this function.
        bResize: When true, the decoded image is resized to `(w, h)`.
        origImgW: First dimension used to reshape the decoded bytes.
        origImgH: Second dimension used to reshape the decoded bytes.

    Returns:
        A tuple `(img, label)`: `img` is a float32 tensor with 3 channels
        and values shifted to be centred on zero (uint8 / 255 - 0.5);
        `label` is the int64 scalar stored in the record, or its one-hot
        encoding of depth `n_class`.
    """
    # string_input_producer needs a 1-D collection of file names, so a lone
    # path is wrapped; lists and tensors are passed through unchanged.
    if isinstance(filename, (str, bytes)):
        files = [filename]
    else:
        files = filename

    filename_queue = tf.train.string_input_producer(files)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 'height', 'width' and 'depth' are parsed because they are part of the
    # record schema, but the shape below comes from the caller's arguments.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Training-time image distortions could be inserted after decoding.
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    # NOTE(review): the bytes are laid out as [origImgW, origImgH, 3]. If the
    # writer stored rows first ([height, width, 3]) this only matches for
    # square images - confirm against the code that wrote the records.
    # NOTE(review): the defaults origImgW=0 / origImgH=0 describe an empty
    # image, so callers presumably always pass the real size - TODO confirm.
    img = tf.reshape(img, [origImgW, origImgH, 3])
    if bResize:
        # NOTE(review): resize_images documents `size` as
        # (new_height, new_width), so `w` lands on the first axis here.
        img = tf.image.resize_images(img, (w, h), method=0)
    # Scale uint8 [0, 255] to [0, 1], then shift so values centre on zero.
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

    # The label is kept as int64 (no cast) before the optional encoding.
    label = features['label']
    if one_hot and n_class:
        label = tf.one_hot(label, n_class)

    return img, label
示例#3
0
def binarize(x, sz=num_alphabet):
    """One-hot encode `x` along a new last axis and return it as floats.

    Args:
        x: Tensor of indices to encode.
        sz: One-hot depth; defaults to the module-level `num_alphabet`.

    Returns:
        The one-hot encoding of `x` (depth `sz`, on the last axis)
        converted with `tf.to_float`.
    """
    # Imported locally so the function carries its own TensorFlow handle.
    from keras.backend import tf

    # Encode with explicit 1 / 0 markers, then convert the result to float.
    encoded = tf.one_hot(x, sz, on_value=1, off_value=0, axis=-1)
    return tf.to_float(encoded)