Ejemplo n.º 1
0
def main():
    """Print label statistics for the MNIST-pair LMDB named on the command line.

    Every record is parsed as a ``mnistpair_pb2.Datum``. Records whose label
    is 0 are counted as positives and all others as negatives; the counts and
    their percentages of the total are printed.
    """
    args = parseArgs()
    env = lmdb.open(args.mnist, readonly=True)
    positives = 0
    negatives = 0
    try:
        with env.begin() as txn:
            # ParseFromString clears the message before parsing, so a single
            # instance can be reused for every record.
            datum = mnistpair_pb2.Datum()
            # An unpositioned cursor iterates from the first key. A
            # first()-then-while-next() loop advances before reading and so
            # never counts the first record.
            for value in txn.cursor().iternext(keys=False, values=True):
                datum.ParseFromString(value)
                if datum.label == 0:
                    positives += 1
                else:
                    negatives += 1
    finally:
        env.close()
    total = positives + negatives

    def percent(count):
        # 100.0 forces true division (also on Python 2) and matches the "%"
        # label; an empty database reports 0.0 instead of dividing by zero.
        return str(100.0 * count / total) if total else str(0.0)

    print("=================================")
    print("Statistics:")
    print("=================================")
    print("Positives     : " + str(positives))
    print("Negatives     : " + str(negatives))
    print("Total         : " + str(total))
    print("Positives (%): " + percent(positives))
    print("Negatives (%): " + percent(negatives))
    print("=================================")
Ejemplo n.º 2
0
def makeDatum(frame1, frame2, label):
    """Pack two equally-shaped frames and a label into a ``mnistpair_pb2.Datum``.

    Args:
        frame1: first frame as a NumPy array indexed ``[row, column]``
            (assumed single-channel, since ``channels`` is fixed to 1).
        frame2: second frame; must have the same shape as ``frame1``.
        label: value stored in ``datum.label``.

    Returns:
        The populated ``Datum``. Frames are stored with ``ndarray.tobytes()``
        in their own dtype, so readers must decode them with that same dtype.

    Raises:
        ValueError: if the two frames differ in shape.
    """
    if frame1.shape != frame2.shape:
        raise ValueError(
            "frame shapes differ: %r vs %r" % (frame1.shape, frame2.shape))
    datum = mnistpair_pb2.Datum()
    datum.channels = 1
    # NumPy shape is (rows, cols) == (height, width). Readers reshape the raw
    # bytes to (channels, height, width), so width must be the column count;
    # swapping them scrambles any non-square frame.
    datum.height = frame1.shape[0]
    datum.width = frame1.shape[1]
    datum.frames.extend([frame1.tobytes(), frame2.tobytes()])
    datum.label = label
    return datum
Ejemplo n.º 3
0
def getBatch(cursor, imageSize, batchSize):
    """Read ``batchSize`` consecutive records from ``cursor`` as flat arrays.

    Reading starts at the cursor's current record. When the end of the
    database is reached the cursor wraps to the first record, so the batch
    is always full.

    Args:
        cursor: LMDB cursor positioned on the first record to read.
        imageSize: side length of the square frames; each frame must decode
            to ``imageSize * imageSize`` uint8 values.
        batchSize: number of records to read; non-positive values yield
            empty arrays.

    Returns:
        ``(cursor, frames1, frames2, labels)``. ``frames1`` and ``frames2``
        have shape ``(batchSize, imageSize * imageSize)`` and ``labels`` has
        shape ``(batchSize,)``, all with NumPy's default ``int`` dtype. The
        cursor is left on the record after the last one read.
    """
    count = max(batchSize, 0)
    pixels = imageSize * imageSize
    # Preallocate instead of vstack/hstack per record, which recopied the
    # whole batch on every iteration (quadratic in batchSize).
    frames1 = np.empty((count, pixels), int)
    frames2 = np.empty((count, pixels), int)
    labels = np.empty(count, int)
    datum = mnistpair_pb2.Datum()
    for row in range(count):
        _, value = cursor.item()
        datum.ParseFromString(value)
        # frombuffer replaces the deprecated binary mode of np.fromstring;
        # row assignment copies (and widens) the values into the int arrays.
        frames1[row] = np.frombuffer(datum.frames[0], dtype=np.uint8)
        frames2[row] = np.frombuffer(datum.frames[1], dtype=np.uint8)
        labels[row] = datum.label
        if not cursor.next():
            cursor.first()
    return cursor, frames1, frames2, labels
Ejemplo n.º 4
0
def main():
    """Display up to ``args.count`` frame pairs from the MNIST-pair LMDB.

    Records are read in key order starting from the first one. Each is parsed
    as a ``mnistpair_pb2.Datum``, and its two frames are reshaped to
    ``(channels, height, width)`` and passed to ``show`` with the key and label.
    """
    args = parseArgs()
    env = lmdb.open(args.mnist, readonly=True)

    def decode(raw, shape):
        # NOTE(review): frames are decoded as uint64 here but as uint8 in
        # getBatch; confirm which dtype the writer actually used.
        # frombuffer replaces the deprecated np.fromstring; copy() keeps the
        # result writable, as fromstring's was.
        return np.frombuffer(raw, dtype=np.uint64).reshape(shape).copy()

    try:
        with env.begin() as txn:
            # An unpositioned cursor iterates from the first key, so the first
            # record is shown too; `>=` shows exactly args.count pairs.
            for index, (key, value) in enumerate(txn.cursor()):
                if index >= args.count:
                    break
                datum = mnistpair_pb2.Datum()
                datum.ParseFromString(value)
                shape = (datum.channels, datum.height, datum.width)
                frame1 = decode(datum.frames[0], shape)
                frame2 = decode(datum.frames[1], shape)
                show(np.squeeze(frame1), np.squeeze(frame2), key, datum.label)
    finally:
        env.close()
Ejemplo n.º 5
0
def getImageSize(cursor):
    """Return ``(channels, height, width)`` as recorded in the first datum.

    The cursor is rewound to the first record as a side effect and is left
    positioned there.
    """
    cursor.first()
    first_record = mnistpair_pb2.Datum()
    first_record.ParseFromString(cursor.item()[1])
    dims = (first_record.channels, first_record.height, first_record.width)
    return dims