Example #1
def _array_to_datum(self, data, scalar_label, encoding):
    if data.ndim != 3:
        raise ValueError('Invalid number of dimensions: %d' % data.ndim)
    if encoding == 'none':
        if data.shape[0] == 3:
            # RGB to BGR
            # XXX see issue #59
            data = data[[2, 1, 0], ...]
        datum = array_to_datum(data, scalar_label)
    else:
        # Transpose to (height, width, channel)
        data = data.transpose((1, 2, 0))
        datum = dataset_pb2.Datum()
        datum.height = data.shape[0]
        datum.width = data.shape[1]
        datum.channels = data.shape[2]
        datum.label = scalar_label
        if data.shape[2] == 1:
            # grayscale
            data = data[:, :, 0]
        s = BytesIO()
        if encoding == 'png':
            PIL.Image.fromarray(data).save(s, format='PNG')
        elif encoding == 'jpg':
            PIL.Image.fromarray(data).save(s, format='JPEG', quality=90)
        else:
            raise ValueError('Invalid encoding type')
        datum.data = s.getvalue()
        datum.encoded = True
    return datum
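
The fancy-index `data[[2, 1, 0], ...]` above reorders the channel axis of a CHW array (RGB to BGR for Caffe), and `transpose((1, 2, 0))` converts CHW to the HWC layout PIL expects. A minimal numpy sketch of both operations (toy data, not part of the original code):

import numpy as np

# Toy 3x2x2 CHW array: channel 0 all red, 1 all green, 2 all blue
chw = np.stack([np.full((2, 2), v, dtype=np.uint8) for v in (255, 128, 0)])

# RGB -> BGR: reorder the channel axis with fancy indexing
bgr = chw[[2, 1, 0], ...]
assert bgr[0, 0, 0] == 0 and bgr[2, 0, 0] == 255

# CHW -> HWC, the layout PIL.Image.fromarray expects
hwc = chw.transpose((1, 2, 0))
assert hwc.shape == (2, 2, 3)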
Example #2
def _array_to_datum(image, label, encoding):
    """
    Create a caffe Datum from a numpy.ndarray
    """
    if not encoding:
        # Transform to caffe's format requirements
        if image.ndim == 3:
            # Transpose to (channels, height, width)
            image = image.transpose((2, 0, 1))
            if image.shape[0] == 3:
                # channel swap
                # XXX see issue #59
                image = image[[2, 1, 0], ...]
        elif image.ndim == 2:
            # Add a channels axis
            image = image[np.newaxis, :, :]
        else:
            raise Exception('Image has unrecognized shape: "%s"' % image.shape)
        datum = array_to_datum(image, label)
    else:
        datum = dataset_pb2.Datum()
        if image.ndim == 3:
            datum.channels = image.shape[2]
        else:
            datum.channels = 1
        datum.height = image.shape[0]
        datum.width = image.shape[1]
        datum.label = label

        s = BytesIO()
        if encoding == 'png':
            PIL.Image.fromarray(image).save(s, format='PNG')
        elif encoding == 'jpg':
            PIL.Image.fromarray(image).save(s, format='JPEG', quality=90)
        else:
            raise ValueError('Invalid encoding type')
        datum.data = s.getvalue()
        datum.encoded = True
    return datum
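
The encoded branch above writes the image into an in-memory buffer instead of a file. A self-contained sketch of that round trip, assuming only numpy and Pillow:

from io import BytesIO

import numpy as np
import PIL.Image

# Encode a grayscale gradient to PNG bytes entirely in memory
image = np.tile(np.arange(64, dtype=np.uint8), (64, 1))
s = BytesIO()
PIL.Image.fromarray(image).save(s, format='PNG')
png_bytes = s.getvalue()

# Round-trip: decode the bytes back into an image
decoded = PIL.Image.open(BytesIO(png_bytes))
assert decoded.size == (64, 64)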
Example #3
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs
    """
    job = job_from_request()
    # Get LMDB
    db = flask.request.args.get('db', 'train')
    task = None
    if 'train' in db.lower():
        task = job.train_db_task()
    elif 'val' in db.lower():
        task = job.val_db_task()
    elif 'test' in db.lower():
        task = job.test_db_task()
    if task is None:
        raise ValueError('No create_db task for {0}'.format(db))
    if task.status != 'D':
        raise ValueError(
            "This create_db task's status should be 'D' but is '{0}'".format(
                task.status))
    if task.backend != 'lmdb':
        raise ValueError(
            "Backend is {0} while expected backend is lmdb".format(
                task.backend))
    db_path = job.path(task.db_name)
    labels = task.get_labels()

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))
    label = flask.request.args.get('label', None)

    if label is not None:
        try:
            label = int(label)
        except ValueError:
            label = None

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    if label is None:
        total_entries = reader.total_entries
    else:
        # After PR#1500, task.distribution[str(label)] is a dictionary
        # with keys = 'count' and 'error_count'
        label_entries = task.distribution[str(label)]
        if isinstance(label_entries, dict):
            total_entries = label_entries['count']
        else:
            total_entries = label_entries

    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for key, value in reader.entries():
        if count >= page * size:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if label is None or datum.label == label:
                if datum.encoded:
                    s = BytesIO()
                    s.write(datum.data)
                    s.seek(0)
                    img = PIL.Image.open(s)
                else:
                    arr = datum_to_array(datum)
                    # CHW -> HWC
                    arr = arr.transpose((1, 2, 0))
                    if arr.shape[2] == 1:
                        # HWC -> HW
                        arr = arr[:, :, 0]
                    elif arr.shape[2] == 3:
                        # BGR -> RGB
                        # XXX see issue #59
                        arr = arr[:, :, [2, 1, 0]]
                    img = PIL.Image.fromarray(arr)
                imgs.append({
                    "label": labels[datum.label],
                    "b64": utils.image.embed_image_html(img)
                })
        if label is None:
            count += 1
        else:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if datum.label == int(label):
                count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/explore.html',
                                 page=page,
                                 size=size,
                                 job=job,
                                 imgs=imgs,
                                 labels=labels,
                                 pages=pages,
                                 label=label,
                                 total_entries=total_entries,
                                 db=db)
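
The paging arithmetic above clamps a window of up to five pages on either side of the current one. A standalone sketch of that logic (the function name is illustrative):

def page_window(page, size, total_entries, radius=5):
    """Clamp a +/- radius window of page numbers, as explore() does."""
    min_page = max(0, page - radius)
    max_page = min((total_entries - 1) // size, page + radius)
    return range(min_page, max_page + 1)

# 100 entries at 25 per page -> pages 0..3
assert list(page_window(0, 25, 100)) == [0, 1, 2, 3]
assert list(page_window(10, 25, 100)) == []  # window past the last page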
Example #4
def create_lmdbs(folder,
                 image_width=None,
                 image_height=None,
                 image_count=None):
    """
    Creates LMDBs for generic inference
    Returns the filename for a test image

    Creates these files in "folder":
        train_images/
        train_labels/
        val_images/
        val_labels/
        train_mean.png
        train_mean.binaryproto
        val_mean.png
        val_mean.binaryproto
        test.png
    """
    if image_width is None:
        image_width = IMAGE_SIZE
    if image_height is None:
        image_height = IMAGE_SIZE

    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    # Used to calculate the gradients later
    yy, xx = np.mgrid[:image_height, :image_width].astype('float')

    for phase, image_count in [('train', train_image_count),
                               ('val', val_image_count)]:
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                             map_async=True,
                             max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                             map_async=True,
                             max_dbs=0)

        image_sum = np.zeros((image_height, image_width), 'float64')

        for i in range(image_count):
            xslope, yslope = np.random.random_sample(2) - 0.5
            a = xslope * 255 / image_width
            b = yslope * 255 / image_height
            image = a * (xx - image_width / 2) + b * (yy -
                                                      image_height / 2) + 127.5

            image_sum += image
            image = image.astype('uint8')

            pil_img = PIL.Image.fromarray(image)

            # create image Datum
            image_datum = dataset_pb2.Datum()
            image_datum.height = image.shape[0]
            image_datum.width = image.shape[1]
            image_datum.channels = 1
            s = BytesIO()
            pil_img.save(s, format='PNG')
            image_datum.data = s.getvalue()
            image_datum.encoded = True
            _write_to_lmdb(image_db, str(i), image_datum.SerializeToString())

            # create label Datum
            label_datum = dataset_pb2.Datum()
            label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
            label_datum.float_data.extend(np.array([xslope, yslope]).flat)
            _write_to_lmdb(label_db, str(i), label_datum.SerializeToString())

        # close databases
        image_db.close()
        label_db.close()

        # save mean
        mean_image = (image_sum / image_count).astype('uint8')
        _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))
        _save_mean(mean_image,
                   os.path.join(folder, '%s_mean.binaryproto' % phase))

    # create test image
    #   The network should be able to easily produce two numbers >1
    xslope, yslope = 0.5, 0.5
    a = xslope * 255 / image_width
    b = yslope * 255 / image_height
    test_image = a * (xx - image_width / 2) + b * (yy -
                                                   image_height / 2) + 127.5
    test_image = test_image.astype('uint8')
    pil_img = PIL.Image.fromarray(test_image)
    test_image_filename = os.path.join(folder, 'test.png')
    pil_img.save(test_image_filename)

    return test_image_filename
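
Each generated image is a linear brightness ramp whose x and y slopes double as the regression labels stored in the label database. A numpy-only sketch of that construction (toy size, illustrative constants):

import numpy as np

WIDTH = HEIGHT = 8  # toy size

yy, xx = np.mgrid[:HEIGHT, :WIDTH].astype('float')
xslope, yslope = 0.5, 0.5  # slopes drawn from [-0.5, 0.5) in the original
a = xslope * 255 / WIDTH
b = yslope * 255 / HEIGHT
image = a * (xx - WIDTH / 2) + b * (yy - HEIGHT / 2) + 127.5

# Centering on 127.5 with |slope| <= 0.5 keeps every pixel in [0, 255]
assert image.min() >= 0 and image.max() <= 255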
Example #5
def analyze_db(
    database,
    only_count=False,
    force_same_shape=False,
    print_data=False,
):
    """
    Looks at the data in a prebuilt database and verifies it
        Also prints out some information about it
    Returns True if all entries are valid

    Arguments:
    database -- path to the database

    Keyword arguments:
    only_count -- only count the entries, don't inspect them
    force_same_shape -- throw an error if not all images have the same shape
    print_data -- print the array for each datum
    """
    start_time = time.time()

    # Open database
    try:
        database = validate_database_path(database)
    except ValueError as e:
        logger.error(str(e))
        return False

    reader = DbReader(database)
    logger.info('Total entries: %s' % reader.total_entries)

    unique_shapes = Counter()

    count = 0
    update_time = None
    for key, value in reader.entries():
        datum = dataset_pb2.Datum()
        datum.ParseFromString(value)

        if print_data:
            array = datum_to_array(datum)
            print('>>> Datum #%d (shape=%s)' % (count, array.shape))
            print(array)

        if (not datum.HasField('height') or datum.height == 0
                or not datum.HasField('width') or datum.width == 0):
            if datum.encoded:
                if force_same_shape or not len(unique_shapes.keys()):
                    # Decode datum to learn the shape
                    s = BytesIO()
                    s.write(datum.data)
                    s.seek(0)
                    img = PIL.Image.open(s)
                    width, height = img.size
                    channels = len(img.split())
                else:
                    # We've already decoded one image, don't bother reading the rest
                    width = '?'
                    height = '?'
                    channels = '?'
            else:
                errstr = 'Shape is not set and datum is not encoded'
                logger.error(errstr)
                raise ValueError(errstr)
        else:
            width, height, channels = datum.width, datum.height, datum.channels

        shape = '%sx%sx%s' % (width, height, channels)

        unique_shapes[shape] += 1

        if force_same_shape and len(unique_shapes.keys()) > 1:
            logger.error("Images with different shapes found: %s and %s" %
                         tuple(unique_shapes.keys()))
            return False

        count += 1
        # Send update every 2 seconds
        if update_time is None or (time.time() - update_time) > 2:
            logger.debug('>>> Key %s' % key)
            print_datum(datum)
            logger.debug('Progress: %s/%s' % (count, reader.total_entries))
            update_time = time.time()

        if only_count:
            # quit after reading one
            count = reader.total_entries
            logger.info('Assuming all entries have same shape ...')
            unique_shapes[list(unique_shapes.keys())[0]] = count
            break

    if count != reader.total_entries:
        logger.warning('LMDB reported %s total entries, but only read %s' %
                       (reader.total_entries, count))

    for key, val in sorted(unique_shapes.items(),
                           key=operator.itemgetter(1),
                           reverse=True):
        logger.info('%s entries found with shape %s (WxHxC)' % (val, key))

    logger.info('Completed in %s seconds.' % (time.time() - start_time, ))
    return True
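
When a datum carries no shape fields, analyze_db recovers the shape by decoding the image bytes with PIL. A self-contained sketch of that recovery (the PNG bytes here are fabricated stand-ins for datum.data):

from io import BytesIO

import numpy as np
import PIL.Image

# Fabricate encoded bytes to stand in for datum.data
buf = BytesIO()
PIL.Image.fromarray(np.zeros((10, 20), dtype=np.uint8)).save(buf, format='PNG')

img = PIL.Image.open(BytesIO(buf.getvalue()))
width, height = img.size      # PIL reports (width, height)
channels = len(img.split())   # one band per channel
assert (width, height, channels) == (20, 10, 1)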
Example #6
def infer(input_list, output_dir, jobs_dir, model_id, epoch, batch_size,
          layers, gpu, input_is_db, resize):
    """
    Perform inference on a list of images using the specified model
    """
    # job directory defaults to that defined in DIGITS config
    if jobs_dir == 'none':
        jobs_dir = digits.config.config_value('jobs_dir')

    # load model job
    model_dir = os.path.join(jobs_dir, model_id)
    assert os.path.isdir(model_dir), "Model dir %s does not exist" % model_dir
    model = Job.load(model_dir)

    # load dataset job
    dataset_dir = os.path.join(jobs_dir, model.dataset_id)
    assert os.path.isdir(
        dataset_dir), "Dataset dir %s does not exist" % dataset_dir
    dataset = Job.load(dataset_dir)
    for task in model.tasks:
        task.dataset = dataset

    # retrieve snapshot file
    task = model.train_task()
    snapshot_filename = None
    epoch = float(epoch)
    if epoch == -1 and len(task.snapshots):
        # use last epoch
        epoch = task.snapshots[-1][1]
        snapshot_filename = task.snapshots[-1][0]
    else:
        for f, e in task.snapshots:
            if e == epoch:
                snapshot_filename = f
                break
    if not snapshot_filename:
        raise InferenceError("Unable to find snapshot for epoch=%s" %
                             repr(epoch))

    # retrieve image dimensions and resize mode
    image_dims = dataset.get_feature_dims()
    height = image_dims[0]
    width = image_dims[1]
    channels = image_dims[2]
    resize_mode = dataset.resize_mode if hasattr(dataset,
                                                 'resize_mode') else 'squash'

    n_input_samples = 0  # number of samples we were able to load
    input_ids = []  # indices of samples within file list
    input_data = []  # sample data

    if input_is_db:
        # load images from database
        reader = DbReader(input_list)
        for key, value in reader.entries():
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if datum.encoded:
                s = BytesIO()
                s.write(datum.data)
                s.seek(0)
                img = PIL.Image.open(s)
                img = np.array(img)
            else:
                arr = datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1, 2, 0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:, :, 0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:, :, [2, 1, 0]]
                img = arr
            input_ids.append(key)
            input_data.append(img)
            n_input_samples = n_input_samples + 1
    else:
        # load paths from file
        with open(input_list) as infile:
            paths = infile.readlines()
        # load and resize images
        for idx, path in enumerate(paths):
            path = path.strip()
            try:
                image = utils.image.load_image(path.strip())
                if resize:
                    image = utils.image.resize_image(image,
                                                     height,
                                                     width,
                                                     channels=channels,
                                                     resize_mode=resize_mode)
                else:
                    image = utils.image.image_to_array(image,
                                                       channels=channels)
                input_ids.append(idx)
                input_data.append(image)
                n_input_samples = n_input_samples + 1
            except utils.errors.LoadImageError as e:
                print(e)

    # perform inference
    visualizations = None

    if n_input_samples == 0:
        raise InferenceError("Unable to load any image from file '%s'" %
                             repr(input_list))
    elif n_input_samples == 1:
        # single image inference
        outputs, visualizations = model.train_task().infer_one(
            input_data[0],
            snapshot_epoch=epoch,
            layers=layers,
            gpu=gpu,
            resize=resize)
    else:
        if layers != 'none':
            raise InferenceError(
                "Layer visualization is not supported for multiple inference")
        outputs = model.train_task().infer_many(input_data,
                                                snapshot_epoch=epoch,
                                                gpu=gpu,
                                                resize=resize)

    # write to hdf5 file
    db_path = os.path.join(output_dir, 'inference.hdf5')
    db = h5py.File(db_path, 'w')

    # write input paths and images to database
    db.create_dataset("input_ids", data=input_ids)
    db.create_dataset("input_data", data=input_data)

    # write outputs to database
    db_outputs = db.create_group("outputs")
    for output_id, output_name in enumerate(outputs.keys()):
        output_data = outputs[output_name]
        output_key = base64.urlsafe_b64encode(str(output_name).encode()).decode()
        dset = db_outputs.create_dataset(output_key, data=output_data)
        # add ID attribute so outputs can be sorted in
        # the order they appear in here
        dset.attrs['id'] = output_id

    # write visualization data
    if visualizations is not None and len(visualizations) > 0:
        db_layers = db.create_group("layers")
        for idx, layer in enumerate(visualizations):
            vis = layer['vis'] if layer['vis'] is not None else np.empty(0)
            dset = db_layers.create_dataset(str(idx), data=vis)
            dset.attrs['name'] = layer['name']
            dset.attrs['vis_type'] = layer['vis_type']
            if 'param_count' in layer:
                dset.attrs['param_count'] = layer['param_count']
            if 'layer_type' in layer:
                dset.attrs['layer_type'] = layer['layer_type']
            dset.attrs['shape'] = layer['data_stats']['shape']
            dset.attrs['mean'] = layer['data_stats']['mean']
            dset.attrs['stddev'] = layer['data_stats']['stddev']
            dset.attrs['histogram_y'] = layer['data_stats']['histogram'][0]
            dset.attrs['histogram_x'] = layer['data_stats']['histogram'][1]
            dset.attrs['histogram_ticks'] = layer['data_stats']['histogram'][2]
    db.close()
    logger.info('Saved data to %s', db_path)
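
The 'id' attribute written above lets a reader restore the original output order even though HDF5 group keys are unordered names. A minimal h5py sketch of the same layout (file and output names are made up):

import h5py
import numpy as np

with h5py.File('toy_inference.hdf5', 'w') as db:
    db.create_dataset('input_ids', data=np.arange(2))
    db.create_dataset('input_data', data=np.zeros((2, 4, 4), dtype='uint8'))
    outputs = db.create_group('outputs')
    for output_id, name in enumerate(['softmax', 'fc1']):
        dset = outputs.create_dataset(name, data=np.random.rand(2, 10))
        dset.attrs['id'] = output_id  # preserves insertion order for readers

with h5py.File('toy_inference.hdf5', 'r') as db:
    names = sorted(db['outputs'], key=lambda k: db['outputs'][k].attrs['id'])
    assert names == ['softmax', 'fc1']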
Example #7
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs
    """
    job = job_from_request()
    # Get LMDB
    db = flask.request.args.get('db')
    db_path = job.path(db)

    if (os.path.basename(db_path) == 'labels'
            and COLOR_PALETTE_ATTRIBUTE in job.extension_userdata
            and job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]):
        # assume single-channel 8-bit palette
        palette = job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]
        palette = np.array(palette).reshape((len(palette) // 3, 3)) / 255.
        # normalize input pixels to [0,1]
        norm = mpl.colors.Normalize(vmin=0, vmax=255)
        # create map
        cmap = plt.cm.ScalarMappable(norm=norm,
                                     cmap=mpl.colors.ListedColormap(palette))
    else:
        cmap = None

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))

    reader = DbReader(db_path)
    count = 0
    imgs = []

    min_page = max(0, page - 5)
    total_entries = reader.total_entries

    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for key, value in reader.entries():
        if count >= page * size:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if not datum.encoded:
                raise RuntimeError("Expected encoded database")
            s = BytesIO()
            s.write(datum.data)
            s.seek(0)
            img = PIL.Image.open(s)
            if cmap and img.mode in ['L', '1']:
                data = np.array(img)
                data = cmap.to_rgba(data) * 255
                data = data.astype('uint8')
                # keep RGB values only, remove alpha channel
                data = data[:, :, 0:3]
                img = PIL.Image.fromarray(data)
            imgs.append({
                "label": None,
                "b64": utils.image.embed_image_html(img)
            })
        count += 1
        if len(imgs) >= size:
            break

    return flask.render_template('datasets/images/explore.html',
                                 page=page,
                                 size=size,
                                 job=job,
                                 imgs=imgs,
                                 labels=None,
                                 pages=pages,
                                 label=None,
                                 total_entries=total_entries,
                                 db=db)
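
The ScalarMappable built above turns single-channel label images into RGB through a palette. A self-contained sketch of that mapping with a two-color palette (colors chosen arbitrarily):

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image

palette = np.array([[0, 0, 0], [255, 0, 0]]) / 255.
norm = mpl.colors.Normalize(vmin=0, vmax=255)
cmap = plt.cm.ScalarMappable(norm=norm,
                             cmap=mpl.colors.ListedColormap(palette))

labels = np.array([[0, 255], [255, 0]], dtype=np.uint8)
data = (cmap.to_rgba(labels) * 255).astype('uint8')
data = data[:, :, 0:3]  # keep RGB, drop the alpha channel
img = PIL.Image.fromarray(data)
assert img.size == (2, 2)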