Пример #1
0
    def path_to_datum(self, path, label,
            image_sum = None):
        """
        Creates a Datum from a path and a label
        May also update image_sum, if computing mean

        Arguments:
        path -- path to the image (filesystem path or URL)
        label -- numeric label for this image's category

        Keyword arguments:
        image_sum -- numpy array that stores a running sum of added images;
                     updated in place when self.compute_mean is set

        Raises:
        ValueError -- if self.encoding is set to something other than
                      'none', 'png' or 'jpg'
        Exception -- if an unencoded image is neither 2-D nor 3-D
        """
        from io import BytesIO

        encoded = bool(self.encoding) and self.encoding != 'none'
        # Reject unknown encodings before doing any work; previously an empty
        # payload was stored with encoded=True and image_sum was still updated.
        if encoded and self.encoding not in ('png', 'jpg'):
            raise ValueError('Invalid encoding type')

        # prepend path with image_folder, if appropriate
        if not utils.is_url(path) and self.image_folder and not os.path.isabs(path):
            path = os.path.join(self.image_folder, path)

        image = utils.image.load_image(path)
        image = utils.image.resize_image(image,
                self.height, self.width,
                channels    = self.channels,
                resize_mode = self.resize_mode,
                )

        if self.compute_mean and image_sum is not None:
            image_sum += image

        if not encoded:
            # Transform to caffe's format requirements
            if image.ndim == 3:
                # Transpose to (channels, height, width)
                image = image.transpose((2,0,1))
                if image.shape[0] == 3:
                    # channel swap
                    # XXX see issue #59
                    image = image[[2,1,0],...]
            elif image.ndim == 2:
                # Add a channels axis
                image = image[np.newaxis,:,:]
            else:
                # Wrap the shape tuple so '%' formats it as one value instead
                # of treating its elements as separate format arguments.
                raise Exception('Image has unrecognized shape: "%s"' % (image.shape,))
            datum = caffe.io.array_to_datum(image, label)
        else:
            datum = caffe_pb2.Datum()
            if image.ndim == 3:
                datum.channels = image.shape[2]
            else:
                datum.channels = 1
            datum.height = image.shape[0]
            datum.width = image.shape[1]
            datum.label = label

            # PNG/JPEG output is binary, so buffer it as bytes.
            s = BytesIO()
            if self.encoding == 'png':
                PIL.Image.fromarray(image).save(s, format='PNG')
            else:
                PIL.Image.fromarray(image).save(s, format='JPEG', quality=90)
            datum.data = s.getvalue()
            datum.encoded = True
        return datum
Пример #2
0
def key_value_example(use_caffe_datum=False):
    """Print key, array shape and label for every record in ./mylmdb.

    Records are decoded as serialized caffe ``Datum`` protobufs when
    ``use_caffe_datum`` is true, otherwise as ASCII JSON dicts with
    'channels', 'height', 'width', 'data' and 'label' keys. The data is
    reshaped to (channels, height, width) uint8.
    """
    lmdb_dir_path = './mylmdb'
    with lmdb.open(lmdb_dir_path, readonly=True) as env:
        with env.begin() as txn:
            cursor = txn.cursor()
            if use_caffe_datum:
                #from caffe.proto import caffe_pb2
                import caffe_pb2

                for k, v in cursor:
                    # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
                    datum = caffe_pb2.Datum()
                    datum.ParseFromString(v)

                    # np.fromstring is deprecated for binary input; frombuffer
                    # reads the same bytes as uint8 without copying.
                    x = np.frombuffer(datum.data, dtype=np.uint8)
                    x = x.reshape(datum.channels, datum.height, datum.width)
                    y = datum.label
                    print(k.decode(), x.shape, y)
            else:
                for k, v in cursor:
                    datum = json.loads(v.decode('ascii'))
                    x = np.array(datum['data'], dtype=np.uint8)
                    x = x.reshape(datum['channels'], datum['height'],
                                  datum['width'])
                    y = datum['label']
                    print(k.decode(), x.shape, y)
Пример #3
0
def read_from_db_example(use_caffe_datum=False):
    """Read record b'00000000' from ./myleveldb and print its shape and label.

    The database is opened with create_if_missing=True, so a missing
    directory yields an empty database. A missing key is reported and the
    function returns early. The record is decoded as a serialized caffe
    ``Datum`` when ``use_caffe_datum`` is true, otherwise as an ASCII JSON
    dict with 'channels', 'height', 'width', 'data' and 'label' keys.
    """
    leveldb_dir_path = './myleveldb'
    db = leveldb.LevelDB(leveldb_dir_path, create_if_missing=True)
    key = b'00000000'
    try:
        raw_datum = db.Get(key)
    except KeyError as ex:
        print('Invalid key, {}: {}.'.format(key, ex))
        return

    if use_caffe_datum:
        #from caffe.proto import caffe_pb2
        import caffe_pb2

        # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
        datum = caffe_pb2.Datum()
        datum.ParseFromString(raw_datum)

        # frombuffer replaces the deprecated binary use of np.fromstring
        x = np.frombuffer(datum.data, dtype=np.uint8)
        x = x.reshape(datum.channels, datum.height, datum.width)
        y = datum.label
    else:
        datum = json.loads(raw_datum.decode('ascii'))

        x = np.array(datum['data'], dtype=np.uint8)
        x = x.reshape(datum['channels'], datum['height'], datum['width'])
        y = datum['label']

    print(x.shape, y)
Пример #4
0
def parse_datum(value):
    """
    Parse a Caffe datum

    Arguments:
    value -- serialized caffe_pb2.Datum

    Returns a numpy image array. Encoded datums are decoded with PIL.
    Raw datums are converted with caffe.io.datum_to_array and rearranged
    from CHW to HWC. A single channel is squeezed to HW, and three channels
    are reordered BGR -> RGB.
    """
    import io

    datum = caffe_pb2.Datum()
    datum.ParseFromString(value)
    if datum.encoded:
        # datum.data holds binary image bytes; wrap them in a bytes buffer
        # (a text StringIO cannot hold them on Python 3).
        img = PIL.Image.open(io.BytesIO(datum.data))
        return np.array(img)

    import caffe.io
    arr = caffe.io.datum_to_array(datum)
    # CHW -> HWC
    arr = arr.transpose((1, 2, 0))
    if arr.shape[2] == 1:
        # HWC -> HW
        arr = arr[:, :, 0]
    elif arr.shape[2] == 3:
        # BGR -> RGB
        # XXX see issue #59
        arr = arr[:, :, [2, 1, 0]]
    return arr
Пример #5
0
 def array_to_datum(self, data, scalar_label, encoding):
     """
     Build a caffe Datum from a (channel, height, width) numpy array.

     encoding -- 'none' stores the raw pixels (three-channel input is
                 reordered RGB -> BGR first); 'png' or 'jpg' stores a
                 compressed image. Anything else raises ValueError.

     Raises ValueError if ``data`` is not 3-dimensional.
     """
     if data.ndim != 3:
         raise ValueError('Invalid number of dimensions: %d' % data.ndim)

     if encoding == 'none':
         if data.shape[0] == 3:
             # RGB to BGR
             # XXX see issue #59
             data = data[[2, 1, 0], ...]
         return caffe.io.array_to_datum(data, scalar_label)

     # Encoders work on (height, width, channel) layout
     hwc = data.transpose((1, 2, 0))
     height, width, channels = hwc.shape
     datum = caffe_pb2.Datum()
     datum.height = height
     datum.width = width
     datum.channels = channels
     datum.label = scalar_label
     # A single channel is handed to PIL as a 2-D grayscale array
     pixels = hwc[:, :, 0] if channels == 1 else hwc

     buf = BytesIO()
     if encoding == 'png':
         PIL.Image.fromarray(pixels).save(buf, format='PNG')
     elif encoding == 'jpg':
         PIL.Image.fromarray(pixels).save(buf, format='JPEG', quality=90)
     else:
         raise ValueError('Invalid encoding type')
     datum.data = buf.getvalue()
     datum.encoded = True
     return datum
Пример #6
0
    def __getitem__(self, index):
        """Return the next (image, label) example from the LMDB cursor.

        NOTE: ``index`` is ignored. Records are served sequentially from
        ``self.lmdb_cursor``, which is rewound to the first record whenever
        it is not positioned on one (its key is empty).

        The uint8 (C, H, W) image is cropped to ``self.crop_size`` x
        ``self.crop_size``. The offset is random when ``self.random`` is set
        and centered otherwise. The crop is mirrored horizontally with
        probability 1/2 when ``self.flip`` is set, then multiplied by
        ``self.scale``.

        Returns:
            (xf32, y): the scaled float32 crop and the datum label.
        """
        if len(self.lmdb_cursor.key()) == 0:
            self.lmdb_cursor.first()

        datum = caffe_pb2.Datum()
        datum.ParseFromString(self.lmdb_cursor.value())
        self.lmdb_cursor.next()

        w = datum.width
        h = datum.height
        c = datum.channels
        y = datum.label

        # frombuffer replaces the deprecated binary use of np.fromstring
        xint8 = np.frombuffer(datum.data, dtype=np.uint8).reshape(c, h, w)

        if self.random:
            # random.randint is inclusive on both ends, so the largest valid
            # offset is (size - crop_size). The former extra "- 1" never
            # picked it and raised ValueError when size == crop_size.
            top = random.randint(0, h - self.crop_size)
            left = random.randint(0, w - self.crop_size)
        else:
            top = (h - self.crop_size) // 2
            left = (w - self.crop_size) // 2

        bottom = top + self.crop_size
        right = left + self.crop_size
        xint8 = xint8[:, top:bottom, left:right]

        if self.flip:
            if random.randint(0, 1):
                xint8 = xint8[:, :, ::-1]

        xf32 = xint8 * np.float32(self.scale)
        return xf32, y
Пример #7
0
def read_from_db_example(use_caffe_datum=False):
    """Read record b'00000000' from ./mylmdb and print its shape and label.

    A missing key is reported and the function returns early (txn.get
    returns None for it). The record is decoded as a serialized caffe
    ``Datum`` when ``use_caffe_datum`` is true, otherwise as an ASCII JSON
    dict with 'channels', 'height', 'width', 'data' and 'label' keys.
    """
    lmdb_dir_path = './mylmdb'
    key = b'00000000'
    with lmdb.open(lmdb_dir_path, readonly=True) as env:
        with env.begin() as txn:
            raw_datum = txn.get(key)

    if raw_datum is None:
        print('Invalid key, {}.'.format(key))
        return

    if use_caffe_datum:
        #from caffe.proto import caffe_pb2
        import caffe_pb2

        # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
        datum = caffe_pb2.Datum()
        datum.ParseFromString(raw_datum)

        # frombuffer replaces the deprecated binary use of np.fromstring
        x = np.frombuffer(datum.data, dtype=np.uint8)
        x = x.reshape(datum.channels, datum.height, datum.width)
        y = datum.label
    else:
        datum = json.loads(raw_datum.decode('ascii'))

        x = np.array(datum['data'], dtype=np.uint8)
        x = x.reshape(datum['channels'], datum['height'], datum['width'])
        y = datum['label']

    print(x.shape, y)
Пример #8
0
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs

    Query parameters (read from flask.request.args):
    db -- database to browse, resolved through job.path()
    page -- zero-based page index (default 0)
    size -- number of images per page (default 25)

    When the job defines a color palette, 'L' and '1' mode images are
    colorized through it before being embedded.

    Raises:
    RuntimeError -- if a displayed entry is not an encoded datum
    """
    import io

    job = job_from_request()
    # Get LMDB
    db = job.path(flask.request.args.get('db'))
    db_path = job.path(db)

    if COLOR_PALETTE_ATTRIBUTE in job.extension_userdata:
        # assume single-channel 8-bit palette
        palette = job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]
        # flat [r, g, b, r, g, b, ...] -> (n_colors, 3) scaled to [0, 1];
        # floor division keeps the reshape argument an int on Python 3
        palette = np.array(palette).reshape((len(palette) // 3, 3)) / 255.
        # normalize input pixels to [0,1]
        norm = mpl.colors.Normalize(vmin=0, vmax=255)
        # create map
        cmap = mpl.pyplot.cm.ScalarMappable(norm=norm,
                                            cmap=mpl.colors.ListedColormap(palette))
    else:
        cmap = None

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    total_entries = reader.total_entries

    # Floor division: range() below requires ints on Python 3
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for _key, value in reader.entries():
        if count >= page * size:
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if not datum.encoded:
                raise RuntimeError("Expected encoded database")
            # Encoded image bytes are binary: use a bytes buffer
            img = PIL.Image.open(io.BytesIO(datum.data))
            if cmap is not None and img.mode in ['L', '1']:
                data = np.array(img)
                data = cmap.to_rgba(data) * 255
                data = data.astype('uint8')
                # keep RGB values only, remove alpha channel
                data = data[:, :, 0:3]
                img = PIL.Image.fromarray(data)
            imgs.append({"label": None, "b64": utils.image.embed_image_html(img)})
        count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/explore.html', page=page, size=size, job=job, imgs=imgs, labels=None, pages=pages, label=None, total_entries=total_entries, db=db)
Пример #9
0
def _array_to_label_datum(labels, single_label):
    """Pack one label (or a pair of labels) into a float caffe Datum.

    With ``single_label`` true, ``labels`` is stored as a 1x1x1 datum.
    Otherwise ``labels`` must unpack into exactly two values, which are
    stored as a 2x1x1 datum. Values go into ``float_data``.
    """
    if single_label:
        values = [labels]
    else:
        first, second = labels
        values = [first, second]

    n_channels = len(values)
    target = np.zeros((n_channels, 1, 1))
    for channel, value in enumerate(values):
        target[channel, 0, 0] = value

    datum = caffe_pb2.Datum()
    datum.channels, datum.height, datum.width = n_channels, 1, 1
    datum.float_data.extend(target.flat)
    return datum
Пример #10
0
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs

    Query parameters (read from flask.request.args):
    db -- database to browse, resolved through job.path()
    page -- zero-based page index (default 0)
    size -- number of images per page (default 25)

    Raises:
    RuntimeError -- if a displayed entry is not an encoded datum
    """
    import io

    job = job_from_request()
    # Get LMDB
    db = job.path(flask.request.args.get('db'))
    db_path = job.path(db)

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    total_entries = reader.total_entries

    # Floor division: range() below requires ints on Python 3
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for _key, value in reader.entries():
        if count >= page * size:
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if not datum.encoded:
                raise RuntimeError("Expected encoded database")
            # Encoded image bytes are binary: use a bytes buffer
            img = PIL.Image.open(io.BytesIO(datum.data))
            imgs.append({
                "label": None,
                "b64": utils.image.embed_image_html(img)
            })
        count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/explore.html',
                                 page=page,
                                 size=size,
                                 job=job,
                                 imgs=imgs,
                                 labels=None,
                                 pages=pages,
                                 label=None,
                                 total_entries=total_entries,
                                 db=db)
Пример #11
0
def transform(prev):
    """Serialize (entry, label) pairs from ``prev`` as 1x1x2064 caffe Datums.

    Yields one serialized Datum per item. Each entry is packed as 2064
    native-order doubles into ``Datum.data``. After ``prev`` is exhausted,
    the element-wise sum of all entries is written to
    ``watson_mean.binaryproto`` in the current directory. That Datum still
    carries the label of the last entry.

    NOTE(review): the file is named "mean" but holds the running totals,
    not the totals divided by the entry count. Confirm this is intended.
    """
    d = caffe_pb2.Datum()
    d.channels = 1
    d.height = 1
    d.width = 2064
    # One native double per element, i.e. "2064d"
    fmt = "%dd" % d.width
    totals = [0] * d.width
    for entry, label in prev:
        totals = [t + e for t, e in zip(totals, entry)]
        d.data = struct.pack(fmt, *entry)
        d.label = label
        yield d.SerializeToString()

    d.data = struct.pack(fmt, *totals)
    # SerializeToString returns bytes: open in binary mode and make sure the
    # handle is closed once the write completes.
    with open("watson_mean.binaryproto", "wb") as f:
        f.write(d.SerializeToString())
Пример #12
0
    def readFromDB(index, size, database):
        """Print shape, label and first channel of every datum in an LMDB.

        Arguments:
        index, size -- currently unused; every record in the database is
                       visited. Kept for interface compatibility.
        database -- path to the LMDB environment (opened read-only)
        """
        with lmdb.open(database, readonly=True) as env:
            with env.begin() as txn:
                datum = caffe_pb2.Datum()
                for _key, value in txn.cursor():
                    datum.ParseFromString(value)
                    label = datum.label
                    channels = datum.channels
                    height = datum.height
                    width = datum.width

                    print(channels)
                    print(height)
                    print(width)

                    # Assumes datum.data is C-order (channels, height, width)
                    # uint8 bytes. Reshape with the real width (the former
                    # per-byte loop used height twice, scrambling non-square
                    # images), then widen to uint16 like the old buffer.
                    readArray = np.frombuffer(datum.data, dtype=np.uint8).reshape(
                        (channels, height, width)).astype(np.uint16)

                    print(label,
                          "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                    print(readArray[0, :, :])
Пример #13
0
def array_to_datum(arr, label=0):
    """Converts a 3-dimensional (channels, height, width) array to a datum.

    If the array has dtype uint8, its raw bytes are stored in ``data``.
    Otherwise the values are stored in ``float_data``.

    Raises ValueError if ``arr`` is not 3-dimensional.
    """
    if arr.ndim != 3:
        raise ValueError('Incorrect array shape.')
    datum = caffe_pb2.Datum()
    datum.channels, datum.height, datum.width = arr.shape
    if arr.dtype == np.uint8:
        # tobytes() is the non-deprecated spelling of tostring(); same bytes
        datum.data = arr.tobytes()
    else:
        datum.float_data.extend(arr.flat)
    datum.label = label
    return datum
Пример #14
0
def write_to_db_example(use_caffe_datum=False):
    """Write 1000 all-zero 3x32x32 uint8 samples (label 0) into ./mylmdb.

    Keys are 8-digit zero-padded ASCII indices. Values are serialized caffe
    ``Datum`` protobufs when ``use_caffe_datum`` is true, otherwise
    ASCII-encoded JSON dicts. The environment stats are printed afterwards.
    A full LMDB map is reported instead of raised.
    """
    num_samples = 1000
    images = np.zeros((num_samples, 3, 32, 32), dtype=np.uint8)
    targets = np.zeros(num_samples, dtype=np.int64)
    _, channels, height, width = images.shape

    lmdb_dir_path = './mylmdb'
    try:
        # Reserve ten times the raw payload size for the memory map.
        with lmdb.open(lmdb_dir_path, map_size=images.nbytes * 10) as env:
            with env.begin(write=True) as txn:  # A transaction object.
                if use_caffe_datum:
                    #from caffe.proto import caffe_pb2
                    import caffe_pb2

                    def serialize(idx):
                        # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
                        datum = caffe_pb2.Datum()
                        datum.channels = channels
                        datum.height = height
                        datum.width = width
                        datum.data = images[idx].tobytes()  # or .tostring() if numpy < 1.9.
                        datum.label = int(targets[idx])
                        return datum.SerializeToString()
                else:
                    def serialize(idx):
                        record = {
                            'channels': channels,
                            'height': height,
                            'width': width,
                            'data': images[idx].tolist(),
                            'label': int(targets[idx]),
                        }
                        return json.dumps(record).encode('ascii')

                for idx in range(num_samples):
                    # The encode is only essential in Python 3.
                    txn.put('{:08}'.format(idx).encode('ascii'), serialize(idx))

            print(env.stat())
    except lmdb.MapFullError as ex:
        print('lmdb.MapFullError raised: {}.'.format(ex))
Пример #15
0
def _array_to_datum(image, label, encoding):
    """
    Create a caffe Datum from a numpy.ndarray

    Arguments:
    image -- HxW or HxWxC array
    label -- label stored on the datum
    encoding -- falsy or 'none' to store raw pixels. These are made
                channels-first, and 3-channel images are reordered to BGR.
                'png' or 'jpg' stores a compressed image instead.

    Raises:
    ValueError -- for any other encoding
    Exception -- if an unencoded image is neither 2-D nor 3-D
    """
    import io

    if not encoding or encoding == 'none':
        # Transform to caffe's format requirements
        if image.ndim == 3:
            # Transpose to (channels, height, width)
            image = image.transpose((2, 0, 1))
            if image.shape[0] == 3:
                # channel swap
                # XXX see issue #59
                image = image[[2, 1, 0], ...]
        elif image.ndim == 2:
            # Add a channels axis
            image = image[np.newaxis, :, :]
        else:
            # Wrap the shape tuple so '%' formats it as a single value
            raise Exception('Image has unrecognized shape: "%s"' % (image.shape,))
        datum = caffe.io.array_to_datum(image, label)
    else:
        datum = caffe_pb2.Datum()
        if image.ndim == 3:
            datum.channels = image.shape[2]
        else:
            datum.channels = 1
        datum.height = image.shape[0]
        datum.width = image.shape[1]
        datum.label = label

        # Save into the buffer object itself. Passing s.getvalue() handed PIL
        # an empty string as a filename. A bytes buffer is needed because the
        # encoded output is binary.
        s = io.BytesIO()
        if encoding == 'png':
            PIL.Image.fromarray(image).save(s, format='PNG')
        elif encoding == 'jpg':
            PIL.Image.fromarray(image).save(s,
                                            format='JPEG',
                                            quality=90)
        else:
            raise ValueError('Invalid encoding type')
        datum.data = s.getvalue()
        datum.encoded = True
    return datum
Пример #16
0
def write_to_db_example(use_caffe_datum=False):
    """Write 1000 all-zero 3x32x32 uint8 samples (label 0) into ./myleveldb.

    Keys are 8-digit zero-padded ASCII indices. Values are serialized caffe
    ``Datum`` protobufs when ``use_caffe_datum`` is true, otherwise
    ASCII-encoded JSON dicts. The database stats are printed afterwards.
    """
    num_samples = 1000
    images = np.zeros((num_samples, 3, 32, 32), dtype=np.uint8)
    targets = np.zeros(num_samples, dtype=np.int64)
    _, channels, height, width = images.shape

    db = leveldb.LevelDB('./myleveldb', create_if_missing=True)

    if use_caffe_datum:
        #from caffe.proto import caffe_pb2
        import caffe_pb2

        def serialize(idx):
            # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
            datum = caffe_pb2.Datum()
            datum.channels = channels
            datum.height = height
            datum.width = width
            datum.data = images[idx].tobytes()  # or .tostring() if numpy < 1.9.
            datum.label = int(targets[idx])
            return datum.SerializeToString()
    else:
        def serialize(idx):
            record = {
                'channels': channels,
                'height': height,
                'width': width,
                'data': images[idx].tolist(),
                'label': int(targets[idx]),
            }
            return json.dumps(record).encode('ascii')

    for idx in range(num_samples):
        # The encode is only essential in Python 3.
        db.Put('{:08}'.format(idx).encode('ascii'), serialize(idx))

    #db.Delete(b'00000000')

    print(db.GetStats())
Пример #17
0
def key_value_example(use_caffe_datum=False):
    """Print key, array shape and label for every record in ./myleveldb.

    The database is opened with create_if_missing=True. Records are decoded
    as serialized caffe ``Datum`` protobufs when ``use_caffe_datum`` is
    true, otherwise as ASCII JSON dicts with 'channels', 'height', 'width',
    'data' and 'label' keys.
    """
    leveldb_dir_path = './myleveldb'
    db = leveldb.LevelDB(leveldb_dir_path, create_if_missing=True)
    if use_caffe_datum:
        #from caffe.proto import caffe_pb2
        import caffe_pb2

        for k, v in db.RangeIter():
            # REF [site] >> https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto
            datum = caffe_pb2.Datum()
            datum.ParseFromString(v)

            # frombuffer replaces the deprecated binary use of np.fromstring
            x = np.frombuffer(datum.data, dtype=np.uint8)
            x = x.reshape(datum.channels, datum.height, datum.width)
            y = datum.label
            print(k.decode(), x.shape, y)
    else:
        for k, v in db.RangeIter():
            datum = json.loads(v.decode('ascii'))
            x = np.array(datum['data'], dtype=np.uint8)
            x = x.reshape(datum['channels'], datum['height'], datum['width'])
            y = datum['label']
            print(k.decode(), x.shape, y)
Пример #18
0
def analyze_db(database,
               only_count=False,
               force_same_shape=False,
               print_data=False,
               ):
    """
    Looks at the data in a prebuilt database and verifies it
        Also prints out some information about it
    Returns True if all entries are valid

    Arguments:
    database -- path to the database

    Keyword arguments:
    only_count -- only count the entries, don't inspect them
    force_same_shape -- throw an error if not all images have the same shape
    print_data -- print the array for each datum

    Returns False if the path fails validate_database_path, or if
    force_same_shape is set and a second shape is found.

    Raises:
    ValueError -- if a datum has no height/width and is not encoded
    """
    import io

    start_time = time.time()

    # Open database
    try:
        database = validate_database_path(database)
    except ValueError as e:
        # str(e) works on Python 2 and 3; BaseException.message is py2-only
        logger.error(str(e))
        return False

    reader = DbReader(database)
    logger.info('Total entries: %s' % reader.total_entries)

    unique_shapes = Counter()

    count = 0
    update_time = None
    for key, value in reader.entries():
        datum = caffe_pb2.Datum()
        datum.ParseFromString(value)

        if print_data:
            array = caffe.io.datum_to_array(datum)
            print(('>>> Datum #%d (shape=%s)' % (count, array.shape)))
            print(array)

        if (not datum.HasField('height') or datum.height == 0 or
                not datum.HasField('width') or datum.width == 0):
            if not datum.encoded:
                errstr = 'Shape is not set and datum is not encoded'
                logger.error(errstr)
                raise ValueError(errstr)
            if force_same_shape or not unique_shapes:
                # Decode datum to learn the shape; the payload is binary
                img = PIL.Image.open(io.BytesIO(datum.data))
                width, height = img.size
                # getbands() only inspects the mode, while split() would
                # decode and copy every band just to count them.
                channels = len(img.getbands())
            else:
                # We've already decoded one image, don't bother reading the rest
                width = height = channels = '?'
        else:
            width, height, channels = datum.width, datum.height, datum.channels

        shape = '%sx%sx%s' % (width, height, channels)

        unique_shapes[shape] += 1

        if force_same_shape and len(unique_shapes) > 1:
            logger.error("Images with different shapes found: %s and %s" % tuple(unique_shapes.keys()))
            return False

        count += 1
        # Send update every 2 seconds
        if update_time is None or (time.time() - update_time) > 2:
            logger.debug('>>> Key %s' % key)
            print_datum(datum)
            logger.debug('Progress: %s/%s' % (count, reader.total_entries))
            update_time = time.time()

        if only_count:
            # quit after reading one
            count = reader.total_entries
            logger.info('Assuming all entries have same shape ...')
            unique_shapes[next(iter(unique_shapes))] = count
            break

    if count != reader.total_entries:
        logger.warning('LMDB reported %s total entries, but only read %s' % (reader.total_entries, count))

    for key, val in sorted(unique_shapes.items(), key=operator.itemgetter(1), reverse=True):
        logger.info('%s entries found with shape %s (WxHxC)' % (val, key))

    logger.info('Completed in %s seconds.' % (time.time() - start_time,))
    return True
Пример #19
0
def image_classification_dataset_explore():
    """
    Returns a gallery consisting of the images of one of the dbs

    Query parameters (read from flask.request.args):
    db -- database to browse; must contain 'train', 'val' or 'test'
    page -- zero-based page index (default 0)
    size -- number of images per page (default 25)
    label -- optional integer class index to filter on
    type -- category type that ``label`` refers to
    ('None' for both label and type means "no filter")

    Ground-truth labels come from one HDF5 file per category type. Every
    file opened here is closed again before the function returns.

    Raises:
    ValueError -- no matching create_db task, task status is not 'D', or
                  a backend other than lmdb
    """
    job = job_from_request()
    # Get LMDB
    db = flask.request.args.get('db', 'train')
    # Initialise so an unrecognised db name reaches the explicit error below
    # instead of failing with UnboundLocalError.
    task = None
    if 'train' in db.lower():
        task = job.train_db_task()
    elif 'val' in db.lower():
        task = job.val_db_task()
    elif 'test' in db.lower():
        task = job.test_db_task()
    if task is None:
        raise ValueError('No create_db task for {0}'.format(db))
    if task.status != 'D':
        raise ValueError(
            "This create_db task's status should be 'D' but is '{0}'".format(
                task.status))
    if task.backend != 'lmdb':
        raise ValueError(
            "Backend is {0} while expected backend is lmdb".format(
                task.backend))
    db_path = job.path(task.db_name)

    category_labels = job.get_labels()
    labels = {}
    try:
        for k, gt_category_type in enumerate(category_labels):
            # Drops the last three characters of db_path (presumably an
            # extension -- TODO confirm) and appends '_<k>.h5'
            fname = db_path[:-3] + '_' + str(k) + '.h5'
            labels[gt_category_type] = h5py.File(fname, 'r')
        return _render_multilabel_explore_page(
            job, db, task, db_path, category_labels, labels)
    finally:
        # The HDF5 files were previously never closed; release them even if
        # reading or rendering fails.
        for h5_file in labels.values():
            h5_file.close()


def _render_multilabel_explore_page(job, db, task, db_path, category_labels,
                                    labels):
    """Collect one page of (optionally label-filtered) images and render it.

    ``labels`` maps each category type to an open HDF5 file whose 'label'
    dataset is indexed by the integer DB key.
    """
    import io

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))
    label = flask.request.args.get('label', None)
    category_type = flask.request.args.get('type', None)

    if label == 'None' and category_type == 'None':
        label = None
        category_type = None

    if label is not None:
        label = int(label)

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    if label is None:
        total_entries = reader.total_entries
    else:
        # list(): dict key views have no .index() on Python 3
        index = list(category_labels.keys()).index(category_type)
        total_entries = task.distribution[index][label]

    # Floor division: range() below requires ints on Python 3
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for key, value in reader.entries():
        if count >= page * size:
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if category_type is not None:
                gt_label = int(labels[category_type]['label'][int(key)])
            else:
                gt_label = [
                    int(labels[c]['label'][int(key)]) for c in category_labels
                ]
            if label is None or gt_label == label:
                if datum.encoded:
                    # Encoded image bytes are binary: use a bytes buffer
                    img = PIL.Image.open(io.BytesIO(datum.data))
                else:
                    import caffe.io
                    arr = caffe.io.datum_to_array(datum)
                    # CHW -> HWC
                    arr = arr.transpose((1, 2, 0))
                    if arr.shape[2] == 1:
                        # HWC -> HW
                        arr = arr[:, :, 0]
                    elif arr.shape[2] == 3:
                        # BGR -> RGB
                        # XXX see issue #59
                        arr = arr[:, :, [2, 1, 0]]
                    img = PIL.Image.fromarray(arr)
                if category_type is not None:
                    gt_class = category_labels[category_type][gt_label]
                else:
                    gt_class = ', '.join(
                        category_labels[c][gt_label[k]]
                        for k, c in enumerate(category_labels))
                imgs.append({
                    "label": gt_class,
                    "b64": utils.image.embed_image_html(img)
                })
        if label is None:
            count += 1
        else:
            gt_label = int(labels[category_type]['label'][int(key)])
            if gt_label == label:
                count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/classification/explore.html',
                                 page=page,
                                 size=size,
                                 job=job,
                                 imgs=imgs,
                                 category_labels=category_labels,
                                 pages=pages,
                                 label=label,
                                 type=category_type,
                                 total_entries=total_entries,
                                 db=db)
Пример #20
0
        image[t,:] = datum.float_data[ datum.width*(t): datum.width*(t+1)]
    return image

if len(sys.argv) != 3:
    # Single-argument print(...) produces the same output on Python 2 and 3
    print("usage: python make_pngs.py [lmdb folder/db] [output dir]")
    print("Will dump out a bunch of RGB pngs whose name will be the key.")
    sys.exit(-1)

lmdb_dir = sys.argv[1]
outdir   = sys.argv[2]

lmdb_env = lmdb.open( lmdb_dir )
lmdb_txn = lmdb_env.begin()
lmdb_cursor = lmdb_txn.cursor()

datum = caffe.Datum()
for key, raw_datum in lmdb_cursor:
    datum.ParseFromString(raw_datum)
    label = datum.label
    data = datum_to_array( datum )
    # Min-max rescale to 0..255 for the greyscale variant below. The old
    # 255/max*(x-min) formula was not a min-max scale and divided by zero on
    # all-zero data. A constant image has zero range and maps to all zeros.
    lo, hi = data.min(), data.max()
    if hi > lo:
        rescaled = (255.0 * (data - lo) / (hi - lo)).astype(np.uint8)
    else:
        rescaled = np.zeros(data.shape, dtype=np.uint8)
    name = '%s/%s.png' % (outdir, key)
    # Two spaces, matching the Python 2 `print "Make ",name` output
    print("Make  " + name)

    # RGB using unscaled ADC values
    matplotlib.image.imsave(name, data)

    # greyscale using scaled ADC values
    #im = Image.fromarray(rescaled)
    #im.save(name)
    
Пример #21
0
def image_classification_dataset_explore():
    """
    Returns a gallery consisting of the images of one of the dbs

    Query parameters (read from flask.request.args):
    db -- database to browse; must contain 'train', 'val' or 'test'
    page -- zero-based page index (default 0)
    size -- number of images per page (default 25)
    label -- optional integer class index to filter on; a non-integer or
             out-of-range value disables the filter

    Raises:
    ValueError -- no matching create_db task, task status is not 'D', or
                  a backend other than lmdb
    """
    import io

    job = job_from_request()
    # Get LMDB
    db = flask.request.args.get('db', 'train')
    # Initialise so an unrecognised db name reaches the explicit error below
    # instead of failing with UnboundLocalError.
    task = None
    if 'train' in db.lower():
        task = job.train_db_task()
    elif 'val' in db.lower():
        task = job.val_db_task()
    elif 'test' in db.lower():
        task = job.test_db_task()
    if task is None:
        raise ValueError('No create_db task for {0}'.format(db))
    if task.status != 'D':
        raise ValueError("This create_db task's status should be 'D' but is '{0}'".format(task.status))
    if task.backend != 'lmdb':
        raise ValueError("Backend is {0} while expected backend is lmdb".format(task.backend))
    db_path = job.path(task.db_name)
    labels = task.get_labels()

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))
    label = flask.request.args.get('label', None)

    if label is not None:
        try:
            label = int(label)
            # Only validates that the class index exists; an out-of-range
            # index previously escaped as an uncaught IndexError.
            labels[label]
        except (ValueError, IndexError):
            label = None

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    if label is None:
        total_entries = reader.total_entries
    else:
        total_entries = task.distribution[str(label)]

    # Floor division: range() below requires ints on Python 3
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for _key, value in reader.entries():
        on_page = count >= page * size
        datum = None
        # Parse each entry at most once, and only when it is displayed or
        # needed to test it against the label filter.
        if on_page or label is not None:
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
        matches = label is None or datum.label == label
        if on_page and matches:
            if datum.encoded:
                # Encoded image bytes are binary: use a bytes buffer
                img = PIL.Image.open(io.BytesIO(datum.data))
            else:
                import caffe.io
                arr = caffe.io.datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1,2,0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:,:,0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:,:,[2,1,0]]
                img = PIL.Image.fromarray(arr)
            imgs.append({"label":labels[datum.label], "b64": utils.image.embed_image_html(img)})
        if matches:
            count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/classification/explore.html', page=page, size=size, job=job, imgs=imgs, labels=labels, pages=pages, label=label, total_entries=total_entries, db=db)
Пример #22
0
def create_lmdbs(folder, file_list, image_count=None, db_batch_size=None):
    """
    Creates train/val LMDBs of random image pairs labelled by class equality

    Each sample stacks two randomly chosen images from file_list into one
    3-channel array (channel 0 = first image, channel 1 = second image,
    channel 2 = zeros). Its label is 1 if both images carry the same class
    label in file_list and 0 otherwise. The images are expected to be 2-D
    (single-channel) and all of the same shape; this is asserted.

    Arguments:
    folder -- output directory
    file_list -- text file with one "<path> <integer label>" entry per line;
                 lines without a trailing integer label are ignored

    Keyword arguments:
    image_count -- number of training pairs (default: TRAIN_IMAGE_COUNT)
    db_batch_size -- entries per LMDB write batch (default: DB_BATCH_SIZE)

    Creates these files in "folder", for each phase in ("train", "val"):
        <phase>_images/
        <phase>_labels/
        <phase>_mean.binaryproto
        <phase>_mean.png
        <phase>_test_same_class_<n>.png
        <phase>_test_different_class_<n>.png

    Raises:
    ValueError -- if file_list contains no labelled image paths
    """
    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    if db_batch_size is None:
        db_batch_size = DB_BATCH_SIZE

    # read file list; a usable line ends with a numerical label
    images = []
    with open(file_list) as f:
        for line in f:
            match = re.match(r'(.*\S)\s+(\d+)$', line.strip())
            if match:
                images.append([match.group(1), int(match.group(2))])

    print("Found %d image paths in image list" % len(images))
    if not images:
        # random.randint(0, -1) below would otherwise fail with an opaque error
        raise ValueError("No labelled image paths found in %s" % file_list)

    for phase, phase_count in [
            ('train', train_image_count),
            ('val', val_image_count)]:

        print("Will create %d pairs of %s images" % (phase_count, phase))

        # create DBs
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                             map_async=True, max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                             map_async=True, max_dbs=0)

        # add up all images to later create mean image
        image_sum = None
        shape = None

        # the first TEST_IMAGE_COUNT pairs of each kind are kept as test images
        testImagesSameClass = []
        testImagesDifferentClass = []

        # arrays for image and label batch writing
        image_batch = []
        label_batch = []

        # report progress ~20 times per phase; clamp to >= 1 so that counts
        # below 20 do not cause a modulo by zero under integer division
        progress_step = max(1, phase_count // 20)

        try:
            for i in range(phase_count):
                # pick up random indices from image list
                index1 = random.randint(0, len(images) - 1)
                index2 = random.randint(0, len(images) - 1)
                # label=1 if images are from the same class otherwise label=0
                label = 1 if images[index1][1] == images[index2][1] else 0
                # load images from files
                image1 = np.array(utils.image.load_image(images[index1][0]))
                image2 = np.array(utils.image.load_image(images[index2][0]))
                if shape is None:
                    # initialize image sum for mean image
                    shape = image1.shape
                    image_sum = np.zeros((3, shape[0], shape[1]), 'float64')
                assert image1.shape == shape and image2.shape == shape

                # create BGR image: blue channel will contain first image,
                # green channel will contain second image
                image_pair = np.zeros(image_sum.shape)
                image_pair[0] = image1
                image_pair[1] = image2

                image_sum += image_pair

                # save test images on first pass
                if label > 0 and len(testImagesSameClass) < TEST_IMAGE_COUNT:
                    testImagesSameClass.append(image_pair)
                if label == 0 and len(testImagesDifferentClass) < TEST_IMAGE_COUNT:
                    testImagesDifferentClass.append(image_pair)

                # encode into Datum object
                image = image_pair.astype('uint8')
                datum = caffe.io.array_to_datum(image, -1)
                image_batch.append([str(i), datum])

                # create label Datum
                label_datum = caffe_pb2.Datum()
                label_datum.channels, label_datum.height, label_datum.width = 1, 1, 1
                label_datum.float_data.extend(np.array([label]).flat)
                label_batch.append([str(i), label_datum])

                # flush a full batch, or whatever is left on the last pair
                if (i % db_batch_size == (db_batch_size - 1)) or (i == phase_count - 1):
                    _write_batch_to_lmdb(image_db, image_batch)
                    _write_batch_to_lmdb(label_db, label_batch)
                    image_batch = []
                    label_batch = []

                if i % progress_step == 0:
                    print("%d/%d" % (i, phase_count))
        finally:
            # close databases even if loading/writing failed
            image_db.close()
            label_db.close()

        # save mean (nothing to average when no pairs were generated)
        if image_sum is not None:
            mean_image = (image_sum / phase_count).astype('uint8')
            _save_mean(mean_image, os.path.join(folder, '%s_mean.binaryproto' % phase))
            _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))

        # create test images
        for idx, image in enumerate(testImagesSameClass):
            _save_image(image, os.path.join(folder, '%s_test_same_class_%d.png' % (phase, idx)))
        for idx, image in enumerate(testImagesDifferentClass):
            _save_image(image, os.path.join(folder, '%s_test_different_class_%d.png' % (phase, idx)))
Пример #23
0
def infer(input_list,
          output_dir,
          jobs_dir,
          model_id,
          epoch,
          batch_size,
          layers,
          gpu,
          input_is_db,
          resize,
          tmp_dir='/home/scania/Scania/Agneev/Tmp'):
    """
    Perform inference on a video or a list of images using the specified
    model, handing every result to a Visualization for rendering

    Arguments:
    input_list -- path to a DB (when input_is_db), a video file
                  (.h264 / .raw) or a .txt file listing one image path per line
    output_dir -- currently unused by this function
    jobs_dir -- DIGITS jobs directory, or 'none' to use the configured one
    model_id -- ID of the model job
    epoch -- snapshot epoch to use; -1 selects the most recent snapshot
    batch_size -- currently unused by this function
    layers -- layer visualization option forwarded to infer_one()
    gpu -- GPU forwarded to infer_one()
    input_is_db -- True if input_list is a database
    resize -- whether to resize inputs to the dataset's dimensions

    Keyword arguments:
    tmp_dir -- directory whose entries (other than subdirectories) are
               deleted before inference; presumably where previously
               rendered segmentation PNGs live -- confirm. Defaults to the
               location that used to be hard-coded here.

    Raises:
    InferenceError -- if no snapshot matches epoch
    """
    # job directory defaults to that defined in DIGITS config
    if jobs_dir == 'none':
        jobs_dir = digits.config.config_value('jobs_dir')

    # load model job
    model_dir = os.path.join(jobs_dir, model_id)
    assert os.path.isdir(model_dir), "Model dir %s does not exist" % model_dir
    model = Job.load(model_dir)

    # load dataset job
    dataset_dir = os.path.join(jobs_dir, model.dataset_id)
    assert os.path.isdir(dataset_dir), "Dataset dir %s does not exist" % dataset_dir
    dataset = Job.load(dataset_dir)
    for task in model.tasks:
        task.dataset = dataset

    # retrieve snapshot file
    task = model.train_task()
    snapshot_filename = None
    epoch = float(epoch)
    if epoch == -1 and len(task.snapshots):
        # use last epoch
        epoch = task.snapshots[-1][1]
        snapshot_filename = task.snapshots[-1][0]
    else:
        for f, e in task.snapshots:
            if e == epoch:
                snapshot_filename = f
                break
    if not snapshot_filename:
        raise InferenceError("Unable to find snapshot for epoch=%s" % repr(epoch))

    # Set color dataset
    vis = Visualization(dataset, colormap='dataset')

    # Delete existing png segmented images; skip subdirectories, on which
    # os.remove() would raise
    for filename in glob.glob(os.path.join(tmp_dir, '*')):
        if not os.path.isdir(filename):
            os.remove(filename)

    # retrieve image dimensions and resize mode
    image_dims = dataset.get_feature_dims()
    height = image_dims[0]
    width = image_dims[1]
    channels = image_dims[2]
    resize_mode = dataset.resize_mode if hasattr(dataset, 'resize_mode') else 'squash'

    def _prepare(image):
        """Convert a loaded image to the array layout fed to the model."""
        if resize:
            return utils.image.resize_image(
                image,
                height,
                width,
                channels=channels,
                resize_mode=resize_mode)
        return utils.image.image_to_array(image, channels=channels)

    def _infer_and_visualize(sample_idx, image, name):
        """Run single-image inference on image and render it via vis."""
        outputs, _ = model.train_task().infer_one(
            image,
            snapshot_epoch=epoch,
            layers=layers,
            gpu=gpu,
            resize=resize)
        out = dict(outputs)
        # 'score' is element [0] of the first output in dict iteration order
        out['score'] = next(iter(out.values()))[0]
        vis.process_data(sample_idx, image, out, name)

    n_input_samples = 0  # number of samples we were able to load/process
    input_ids = []       # indices of samples within file list
    input_data = []      # sample data

    if input_is_db:
        # NOTE(review): samples loaded from a DB are collected here but no
        # inference is run on them in this function -- confirm intent.
        reader = DbReader(input_list)
        for key, value in reader.entries():
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if datum.encoded:
                s = StringIO()
                s.write(datum.data)
                s.seek(0)
                img = PIL.Image.open(s)
                img = np.array(img)
            else:
                import caffe.io
                arr = caffe.io.datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1, 2, 0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:, :, 0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:, :, [2, 1, 0]]
                img = arr
            input_ids.append(key)
            input_data.append(img)
            n_input_samples = n_input_samples + 1
    else:
        try:
            if input_list.endswith(('.h264', '.raw')):
                logging.info('Reading video...')
                cap = cv2.VideoCapture(input_list)
                logging.debug('Opened video capture %s', cap)
                try:
                    frame_no = 0
                    while True:
                        # property id 1 is CAP_PROP_POS_FRAMES: seek to frame
                        cap.set(1, frame_no)
                        ret, cv2_im = cap.read()
                        if not ret:
                            # end of stream (or unreadable frame)
                            break
                        cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
                        image = _prepare(PIL.Image.fromarray(cv2_im))
                        _infer_and_visualize(n_input_samples, image, 'Video_file')
                        n_input_samples = n_input_samples + 1
                        # only every 30th frame is processed
                        frame_no = frame_no + 30
                finally:
                    cap.release()

            elif input_list.endswith('.txt'):
                logging.info('Reading images...')
                with open(input_list) as infile:
                    paths = [line.strip() for line in infile]
                # load, resize and infer each listed image
                for path in paths:
                    if not path:
                        continue  # blank line
                    try:
                        image = _prepare(utils.image.load_image(path))
                        # name results after the file, minus its extension
                        # (splitext keeps any earlier dots in the name)
                        filename = os.path.splitext(os.path.basename(path))[0]
                        _infer_and_visualize(n_input_samples, image, filename)
                        n_input_samples = n_input_samples + 1
                    except utils.errors.LoadImageError as e:
                        print(e)
            else:
                print('Cannot read image or video file. \nPlease provide .h264, .raw or .txt file only.')
        except cv2.error as e:
            print(e)
Пример #24
0
def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
    """
    Creates LMDBs of synthetic gradient images for generic inference
    Returns the filename for a test image

    Every image is a linear intensity ramp equal to 127.5 at the image centre,
    whose x/y slopes are drawn uniformly from [-0.5, 0.5). The label of each
    image is the float pair (xslope, yslope).

    Arguments:
    folder -- output directory

    Keyword arguments:
    image_width -- image width (default: IMAGE_SIZE)
    image_height -- image height (default: IMAGE_SIZE)
    image_count -- number of training images (default: TRAIN_IMAGE_COUNT)

    Creates these files in "folder":
        train_images/
        train_labels/
        val_images/
        val_labels/
        train_mean.png
        train_mean.binaryproto
        val_mean.png
        val_mean.binaryproto
        test.png
    """
    if image_width is None:
        image_width = IMAGE_SIZE
    if image_height is None:
        image_height = IMAGE_SIZE

    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    # pixel coordinate grids used to calculate the gradients
    yy, xx = np.mgrid[:image_height, :image_width].astype('float')

    def _gradient(xslope, yslope):
        """Return a float (height, width) ramp through 127.5 at the centre."""
        a = xslope * 255 / image_width
        b = yslope * 255 / image_height
        return a * (xx - image_width/2) + b * (yy - image_height/2) + 127.5

    for phase, phase_count in [
            ('train', train_image_count),
            ('val', val_image_count)]:
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                map_async=True,
                max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                map_async=True,
                max_dbs=0)

        image_sum = np.zeros((image_height, image_width), 'float64')

        try:
            for i in range(phase_count):
                xslope, yslope = np.random.random_sample(2) - 0.5
                image = _gradient(xslope, yslope)

                image_sum += image
                image = image.astype('uint8')

                pil_img = PIL.Image.fromarray(image)

                # create image Datum (single channel, PNG-encoded)
                image_datum = caffe_pb2.Datum()
                image_datum.height = image.shape[0]
                image_datum.width = image.shape[1]
                image_datum.channels = 1
                s = StringIO()
                pil_img.save(s, format='PNG')
                image_datum.data = s.getvalue()
                image_datum.encoded = True
                _write_to_lmdb(image_db, str(i), image_datum.SerializeToString())

                # create label Datum holding (xslope, yslope)
                label_datum = caffe_pb2.Datum()
                label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
                label_datum.float_data.extend(np.array([xslope, yslope]).flat)
                _write_to_lmdb(label_db, str(i), label_datum.SerializeToString())
        finally:
            # close databases even if generation/writing failed
            image_db.close()
            label_db.close()

        # save mean
        mean_image = (image_sum / phase_count).astype('uint8')
        _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))
        _save_mean(mean_image, os.path.join(folder, '%s_mean.binaryproto' % phase))

    # create test image with both slopes at 0.5, i.e. at the (exclusive)
    # upper edge of the range sampled for training
    test_image = _gradient(0.5, 0.5).astype('uint8')
    pil_img = PIL.Image.fromarray(test_image)
    test_image_filename = os.path.join(folder, 'test.png')
    pil_img.save(test_image_filename)

    return test_image_filename
Пример #25
0
def analyze_db(
    database,
    only_count=False,
    force_same_shape=False,
):
    """
    Looks at the data in a prebuilt database and verifies it
        Also prints out some information about it

    Returns False if the database path fails validation or, with
    force_same_shape, if entries of differing shapes are found; True otherwise.

    Arguments:
    database -- path to the database

    Keyword arguments:
    only_count -- only count the entries, don't inspect them (reads a single
                  entry and assumes all others share its shape)
    force_same_shape -- throw an error if not all images have the same shape
    """
    start_time = time.time()

    # Open database
    try:
        database = validate_database_path(database)
    except ValueError as e:
        # str(e) works on Python 2 and 3; BaseException.message is Py2-only
        logger.error(str(e))
        return False

    reader = DbReader(database)
    logger.info('Total entries: %s' % reader.total_entries)

    # maps 'WxHxC' shape strings to the number of entries with that shape
    unique_shapes = Counter()

    count = 0
    update_time = None
    for key, value in reader.entries():
        datum = caffe_pb2.Datum()
        datum.ParseFromString(value)

        shape = '%sx%sx%s' % (datum.width, datum.height, datum.channels)
        unique_shapes[shape] += 1

        # checked on every entry, so at most two shapes exist when this fires
        if force_same_shape and len(unique_shapes) > 1:
            logger.error("Images with different shapes found: %s and %s" %
                         tuple(unique_shapes.keys()))
            return False

        count += 1
        # Send update every 2 seconds
        if update_time is None or (time.time() - update_time) > 2:
            logger.debug('>>> Key %s' % key)
            print_datum(datum)
            logger.debug('Progress: %s/%s' % (count, reader.total_entries))
            update_time = time.time()

        if only_count:
            # quit after reading one; its shape is the only one seen so far
            count = reader.total_entries
            logger.info('Assuming all entries have same shape ...')
            unique_shapes[shape] = count
            break

    if count != reader.total_entries:
        logger.warning('LMDB reported %s total entries, but only read %s' %
                       (reader.total_entries, count))

    for key, val in sorted(unique_shapes.items(),
                           key=operator.itemgetter(1),
                           reverse=True):
        logger.info('%s entries found with shape %s (WxHxC)' % (val, key))

    logger.info('Completed in %s seconds.' % (time.time() - start_time, ))
    return True
Пример #26
0
def infer(input_list, output_dir, jobs_dir, model_id, epoch, batch_size,
          layers, gpu, input_is_db, resize):
    """
    Perform inference on a list of images using the specified model

    Arguments:
    input_list -- path to a DB (when input_is_db) or to a text file listing
                  one image path per line
    output_dir -- directory that receives inference.hdf5
    jobs_dir -- DIGITS jobs directory, or 'none' to use the configured one
    model_id -- ID of the model job
    epoch -- snapshot epoch to use; -1 selects the most recent snapshot
    batch_size -- currently unused by this function
    layers -- layer visualization option; must be 'none' for several inputs
    gpu -- GPU forwarded to the inference call
    input_is_db -- True if input_list is a database
    resize -- whether to resize inputs to the dataset's dimensions

    Writes <output_dir>/inference.hdf5 with datasets "input_ids" and
    "input_data", a group "outputs" (one dataset per output, keyed by the
    URL-safe base64 of its name) and, when available, a group "layers".

    Raises:
    InferenceError -- if no snapshot matches epoch, no image could be loaded,
                      or layer visualization is requested for several inputs
    """
    # job directory defaults to that defined in DIGITS config
    if jobs_dir == 'none':
        jobs_dir = digits.config.config_value('jobs_dir')

    # load model job
    model_dir = os.path.join(jobs_dir, model_id)
    assert os.path.isdir(model_dir), "Model dir %s does not exist" % model_dir
    model = Job.load(model_dir)

    # load dataset job
    dataset_dir = os.path.join(jobs_dir, model.dataset_id)
    assert os.path.isdir(
        dataset_dir), "Dataset dir %s does not exist" % dataset_dir
    dataset = Job.load(dataset_dir)
    for task in model.tasks:
        task.dataset = dataset

    # retrieve snapshot file
    task = model.train_task()
    snapshot_filename = None
    epoch = float(epoch)
    if epoch == -1 and len(task.snapshots):
        # use last epoch
        epoch = task.snapshots[-1][1]
        snapshot_filename = task.snapshots[-1][0]
    else:
        for f, e in task.snapshots:
            if e == epoch:
                snapshot_filename = f
                break
    if not snapshot_filename:
        raise InferenceError("Unable to find snapshot for epoch=%s" %
                             repr(epoch))

    # retrieve image dimensions and resize mode
    image_dims = dataset.get_feature_dims()
    height = image_dims[0]
    width = image_dims[1]
    channels = image_dims[2]
    resize_mode = dataset.resize_mode if hasattr(dataset,
                                                 'resize_mode') else 'squash'

    n_input_samples = 0  # number of samples we were able to load
    input_ids = []  # indices of samples within file list
    input_data = []  # sample data

    if input_is_db:
        # load images from database
        reader = DbReader(input_list)
        for key, value in reader.entries():
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if datum.encoded:
                s = StringIO()
                s.write(datum.data)
                s.seek(0)
                img = PIL.Image.open(s)
                img = np.array(img)
            else:
                import caffe.io
                arr = caffe.io.datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1, 2, 0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:, :, 0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:, :, [2, 1, 0]]
                img = arr
            input_ids.append(key)
            input_data.append(img)
            n_input_samples = n_input_samples + 1
    else:
        # load paths from file
        with open(input_list) as infile:
            paths = infile.readlines()
        # load and resize images; ids are line indices within the file
        for idx, path in enumerate(paths):
            path = path.strip()
            if not path:
                continue  # blank line
            try:
                image = utils.image.load_image(path)
                if resize:
                    image = utils.image.resize_image(image,
                                                     height,
                                                     width,
                                                     channels=channels,
                                                     resize_mode=resize_mode)
                else:
                    image = utils.image.image_to_array(image,
                                                       channels=channels)
                input_ids.append(idx)
                input_data.append(image)
                n_input_samples = n_input_samples + 1
            except utils.errors.LoadImageError as e:
                print(e)

    # perform inference
    visualizations = None

    if n_input_samples == 0:
        # %r already quotes the path; no extra quotes around it
        raise InferenceError("Unable to load any image from file %r" %
                             (input_list,))
    elif n_input_samples == 1:
        # single image inference
        outputs, visualizations = model.train_task().infer_one(
            input_data[0],
            snapshot_epoch=epoch,
            layers=layers,
            gpu=gpu,
            resize=resize)
    else:
        if layers != 'none':
            raise InferenceError(
                "Layer visualization is not supported for multiple inference")
        outputs = model.train_task().infer_many(input_data,
                                                snapshot_epoch=epoch,
                                                gpu=gpu,
                                                resize=resize)

    # write to hdf5 file; the context manager closes it even if a write fails
    db_path = os.path.join(output_dir, 'inference.hdf5')
    with h5py.File(db_path, 'w') as db:
        # write input paths and images to database
        db.create_dataset("input_ids", data=input_ids)
        db.create_dataset("input_data", data=input_data)

        # write outputs to database
        db_outputs = db.create_group("outputs")
        for output_id, output_name in enumerate(outputs.keys()):
            output_data = outputs[output_name]
            output_key = base64.urlsafe_b64encode(str(output_name))
            dset = db_outputs.create_dataset(output_key, data=output_data)
            # add ID attribute so outputs can be sorted in
            # the order they appear in here
            dset.attrs['id'] = output_id

        # write visualization data
        if visualizations is not None and len(visualizations) > 0:
            db_layers = db.create_group("layers")
            for idx, layer in enumerate(visualizations):
                vis = layer['vis'] if layer['vis'] is not None else np.empty(0)
                dset = db_layers.create_dataset(str(idx), data=vis)
                dset.attrs['name'] = layer['name']
                dset.attrs['vis_type'] = layer['vis_type']
                if 'param_count' in layer:
                    dset.attrs['param_count'] = layer['param_count']
                if 'layer_type' in layer:
                    dset.attrs['layer_type'] = layer['layer_type']
                dset.attrs['shape'] = layer['data_stats']['shape']
                dset.attrs['mean'] = layer['data_stats']['mean']
                dset.attrs['stddev'] = layer['data_stats']['stddev']
                dset.attrs['histogram_y'] = layer['data_stats']['histogram'][0]
                dset.attrs['histogram_x'] = layer['data_stats']['histogram'][1]
                dset.attrs['histogram_ticks'] = layer['data_stats']['histogram'][2]
    logger.info('Saved data to %s', db_path)
Пример #27
0
def infer(input_list, output_dir, jobs_dir, model_id, epoch, batch_size,
          layers, gpu, input_is_db, label_file, resize):
    """
    Perform inference on a list of images using the specified model

    Optionally (when the module-level `fig` is not None) plots the inputs and
    their top-NUM_TOPK_CLASSES scores on a NUM_ROWS x NUM_COLS grid, saving
    the figure to ./test.pdf. NOTE(review): `fig`, `pl`, NUM_ROWS, NUM_COLS
    and NUM_TOPK_CLASSES are not defined in this function; presumably
    module-level globals -- confirm.

    Arguments:
    input_list -- path to a DB (when input_is_db) or to a text file listing
                  one image path per line
    output_dir -- directory that receives inference.hdf5
    jobs_dir -- DIGITS jobs directory, or 'none' to use the configured one
    model_id -- ID of the model job
    epoch -- snapshot epoch to use; -1 selects the most recent snapshot
    batch_size -- currently unused by this function
    layers -- layer visualization option; must be 'none' for several inputs
    gpu -- GPU forwarded to the inference call
    input_is_db -- True if input_list is a database
    label_file -- class-name file read with np.loadtxt(dtype='object');
                  only used to annotate the plots
    resize -- whether to resize inputs to the dataset's dimensions

    Writes <output_dir>/inference.hdf5 with datasets "input_ids" and
    "input_data", a group "outputs" (one dataset per output, keyed by the
    URL-safe base64 of its name) and, when available, a group "layers".

    Raises:
    InferenceError -- if no snapshot matches epoch, no image could be loaded,
                      or layer visualization is requested for several inputs
    """
    # job directory defaults to that defined in DIGITS config
    if jobs_dir == 'none':
        jobs_dir = digits.config.config_value('jobs_dir')

    # load model job
    model_dir = os.path.join(jobs_dir, model_id)
    assert os.path.isdir(model_dir), "Model dir %s does not exist" % model_dir
    model = Job.load(model_dir)

    # load dataset job
    dataset_dir = os.path.join(jobs_dir, model.dataset_id)
    assert os.path.isdir(
        dataset_dir), "Dataset dir %s does not exist" % dataset_dir
    dataset = Job.load(dataset_dir)
    for task in model.tasks:
        task.dataset = dataset

    # retrieve snapshot file
    task = model.train_task()
    snapshot_filename = None
    epoch = float(epoch)
    if epoch == -1 and len(task.snapshots):
        # use last epoch
        epoch = task.snapshots[-1][1]
        snapshot_filename = task.snapshots[-1][0]
    else:
        for f, e in task.snapshots:
            if e == epoch:
                snapshot_filename = f
                break
    if not snapshot_filename:
        raise InferenceError("Unable to find snapshot for epoch=%s" %
                             repr(epoch))

    # retrieve image dimensions and resize mode
    image_dims = dataset.get_feature_dims()
    height = image_dims[0]
    width = image_dims[1]
    channels = image_dims[2]
    resize_mode = dataset.resize_mode if hasattr(dataset,
                                                 'resize_mode') else 'squash'

    n_input_samples = 0  # number of samples we were able to load
    input_ids = []  # indices of samples within file list
    input_data = []  # sample data

    if input_is_db:
        # load images from database
        reader = DbReader(input_list)
        for key, value in reader.entries():
            datum = caffe_pb2.Datum()
            datum.ParseFromString(value)
            if datum.encoded:
                s = StringIO()
                s.write(datum.data)
                s.seek(0)
                img = PIL.Image.open(s)
                img = np.array(img)
            else:
                import caffe.io
                arr = caffe.io.datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1, 2, 0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:, :, 0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:, :, [2, 1, 0]]
                img = arr
            input_ids.append(key)
            input_data.append(img)
            n_input_samples = n_input_samples + 1
    else:
        # load paths from file
        with open(input_list) as infile:
            paths = infile.readlines()
        # load and resize images; ids are line indices within the file
        for idx, path in enumerate(paths):
            path = path.strip()
            if not path:
                continue  # blank line
            try:
                image = utils.image.load_image(path)
                if resize:
                    image = utils.image.resize_image(image,
                                                     height,
                                                     width,
                                                     channels=channels,
                                                     resize_mode=resize_mode)
                else:
                    image = utils.image.image_to_array(image,
                                                       channels=channels)
                input_ids.append(idx)
                input_data.append(image)
                n_input_samples = n_input_samples + 1
            except utils.errors.LoadImageError as e:
                print(e)

    # the plot grid has one input row plus one score row per NUM_ROWS entry
    grid_size = NUM_ROWS * NUM_COLS
    labels = None
    if fig is not None:
        labels = np.loadtxt(label_file, dtype='object')
        # Plot original images to grid; only as many as were loaded and fit
        for idx in range(min(grid_size, n_input_samples)):
            row, col = divmod(idx, NUM_COLS)
            pl.subplot(NUM_ROWS * 2, NUM_COLS,
                       row * 2 * NUM_COLS + col + 1)
            pl.xticks([])
            pl.yticks([])
            pl.imshow(input_data[idx], interpolation='nearest')

    # perform inference
    visualizations = None

    logger.info('Inference')
    if n_input_samples == 0:
        # %r already quotes the path; no extra quotes around it
        raise InferenceError("Unable to load any image from file %r" %
                             (input_list,))
    elif n_input_samples == 1:
        # single image inference
        logger.info('Start')
        outputs, visualizations = model.train_task().infer_one(
            input_data[0],
            snapshot_epoch=epoch,
            layers=layers,
            gpu=gpu,
            resize=resize)
        logger.info('Done!')
    else:
        if layers != 'none':
            raise InferenceError(
                "Layer visualization is not supported for multiple inference")
        outputs = model.train_task().infer_many(input_data,
                                                snapshot_epoch=epoch,
                                                gpu=gpu,
                                                resize=resize)

    logger.info('Now it\'s time to pass results to write')

    # write to hdf5 file; the context manager closes it even if a write fails
    db_path = os.path.join(output_dir, 'inference.hdf5')
    with h5py.File(db_path, 'w') as db:
        # write input paths and images to database
        db.create_dataset("input_ids", data=input_ids)
        db.create_dataset("input_data", data=input_data)

        # write outputs to database
        db_outputs = db.create_group("outputs")
        for output_id, output_name in enumerate(outputs.keys()):
            output_data = outputs[output_name]
            if fig is not None:
                # Plot top-K inferences on the row below each input image
                for elem_id, elem_data in enumerate(output_data):
                    if elem_id >= grid_size:
                        break  # no subplot slot left for this entry
                    row, col = divmod(elem_id, NUM_COLS)
                    img_labels = sorted(zip(elem_data, labels),
                                        key=lambda x: x[0])[-NUM_TOPK_CLASSES:]
                    ax = pl.subplot(NUM_ROWS * 2,
                                    NUM_COLS, (row * 2 + 1) * NUM_COLS + col + 1,
                                    aspect='equal')
                    ax.yaxis.set_label_position("right")
                    ax.yaxis.set_label_coords(1.25, 0.5)
                    pl.ylabel('Confidence score', rotation=-90, fontsize=16)

                    # bar geometry; distinct names so the image height/width
                    # read from the dataset are not overwritten
                    bar_height = 0.5
                    ylocs = np.arange(NUM_TOPK_CLASSES) * bar_height + 0.1
                    bar_width = max(ylocs)
                    top_class = img_labels[-1][1]
                    # highlight the highest-scoring class in red
                    pl.barh(ylocs, [l[0] * bar_width for l in img_labels],
                            height=bar_height,
                            color=['r' if l[1] == top_class else 'b'
                                   for l in img_labels])
                    pl.yticks(ylocs + bar_height / 2, [l[1] for l in img_labels],
                              fontsize=14)
                    pl.xticks([0, bar_width / 2.0, bar_width], ['0%', '50%', '100%'])
                    pl.ylim(0, ylocs[-1] + bar_height + 0.1)
                pl.tight_layout()
                # save before show(): a blocking show() closes the figure
                fig.savefig('./test.pdf', dpi=300)
                pl.show()
            output_key = base64.urlsafe_b64encode(str(output_name))
            dset = db_outputs.create_dataset(output_key, data=output_data)
            # add ID attribute so outputs can be sorted in
            # the order they appear in here
            dset.attrs['id'] = output_id

        # write visualization data
        if visualizations is not None and len(visualizations) > 0:
            db_layers = db.create_group("layers")
            for idx, layer in enumerate(visualizations):
                vis = layer['vis'] if layer['vis'] is not None else np.empty(0)
                dset = db_layers.create_dataset(str(idx), data=vis)
                dset.attrs['name'] = layer['name']
                dset.attrs['vis_type'] = layer['vis_type']
                if 'param_count' in layer:
                    dset.attrs['param_count'] = layer['param_count']
                if 'layer_type' in layer:
                    dset.attrs['layer_type'] = layer['layer_type']
                dset.attrs['shape'] = layer['data_stats']['shape']
                dset.attrs['mean'] = layer['data_stats']['mean']
                dset.attrs['stddev'] = layer['data_stats']['stddev']
                dset.attrs['histogram_y'] = layer['data_stats']['histogram'][0]
                dset.attrs['histogram_x'] = layer['data_stats']['histogram'][1]
                dset.attrs['histogram_ticks'] = layer['data_stats']['histogram'][2]
    logger.info('Saved data to %s', db_path)