def _array_to_datum(self, data, scalar_label, encoding):
    if data.ndim != 3:
        raise ValueError('Invalid number of dimensions: %d' % data.ndim)
    if encoding == 'none':
        if data.shape[0] == 3:
            # RGB to BGR
            # XXX see issue #59
            data = data[[2, 1, 0], ...]
        datum = array_to_datum(data, scalar_label)
    else:
        # Transpose to (height, width, channel)
        data = data.transpose((1, 2, 0))
        datum = dataset_pb2.Datum()
        datum.height = data.shape[0]
        datum.width = data.shape[1]
        datum.channels = data.shape[2]
        datum.label = scalar_label
        if data.shape[2] == 1:
            # grayscale
            data = data[:, :, 0]
        s = BytesIO()
        if encoding == 'png':
            PIL.Image.fromarray(data).save(s, format='PNG')
        elif encoding == 'jpg':
            PIL.Image.fromarray(data).save(s, format='JPEG', quality=90)
        else:
            raise ValueError('Invalid encoding type')
        datum.data = s.getvalue()
        datum.encoded = True
    return datum
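
# Hedged usage sketch for the method above; `writer` stands in for an
# instance of the enclosing class and is purely illustrative (assumes
# numpy is imported as np in this module):
def _demo_encode_datum(writer):
    # random 3x32x32 CHW uint8 image, PNG-encoded into the Datum
    chw = np.random.randint(0, 256, (3, 32, 32), dtype=np.uint8)
    datum = writer._array_to_datum(chw, scalar_label=7, encoding='png')
    assert datum.encoded
    assert (datum.height, datum.width, datum.channels) == (32, 32, 3)
    return datum
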
def _array_to_datum(image, label, encoding):
    """
    Create a caffe Datum from a numpy.ndarray
    """
    if not encoding:
        # Transform to caffe's format requirements
        if image.ndim == 3:
            # Transpose to (channels, height, width)
            image = image.transpose((2, 0, 1))
            if image.shape[0] == 3:
                # channel swap
                # XXX see issue #59
                image = image[[2, 1, 0], ...]
        elif image.ndim == 2:
            # Add a channels axis
            image = image[np.newaxis, :, :]
        else:
            raise Exception('Image has unrecognized shape: "%s"' % image.shape)
        datum = array_to_datum(image, label)
    else:
        datum = dataset_pb2.Datum()
        if image.ndim == 3:
            datum.channels = image.shape[2]
        else:
            datum.channels = 1
        datum.height = image.shape[0]
        datum.width = image.shape[1]
        datum.label = label
        s = BytesIO()
        if encoding == 'png':
            PIL.Image.fromarray(image).save(s, format='PNG')
        elif encoding == 'jpg':
            PIL.Image.fromarray(image).save(s, format='JPEG', quality=90)
        else:
            raise ValueError('Invalid encoding type')
        datum.data = s.getvalue()
        datum.encoded = True
    return datum
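
# Hedged sketch of the unencoded path above (assumes np and caffe's
# array_to_datum are available in this module, as the function requires):
def _demo_raw_datum():
    # a 2-D grayscale image gains a leading channel axis -> (1, 24, 24)
    gray = np.zeros((24, 24), dtype=np.uint8)
    datum = _array_to_datum(gray, 0, None)
    assert datum.channels == 1 and not datum.encoded
    return datum
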
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs
    """
    job = job_from_request()

    # Get LMDB
    db = flask.request.args.get('db', 'train')
    if 'train' in db.lower():
        task = job.train_db_task()
    elif 'val' in db.lower():
        task = job.val_db_task()
    elif 'test' in db.lower():
        task = job.test_db_task()
    else:
        task = None
    if task is None:
        raise ValueError('No create_db task for {0}'.format(db))
    if task.status != 'D':
        raise ValueError(
            "This create_db task's status should be 'D' but is '{0}'".format(
                task.status))
    if task.backend != 'lmdb':
        raise ValueError(
            "Backend is {0} while expected backend is lmdb".format(
                task.backend))
    db_path = job.path(task.db_name)
    labels = task.get_labels()
    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))
    label = flask.request.args.get('label', None)
    if label is not None:
        try:
            label = int(label)
        except ValueError:
            label = None

    reader = DbReader(db_path)
    count = 0
    imgs = []
    min_page = max(0, page - 5)
    if label is None:
        total_entries = reader.total_entries
    else:
        # After PR#1500, task.distribution[str(label)] is a dictionary
        # with keys = 'count' and 'error_count'
        label_entries = task.distribution[str(label)]
        if isinstance(label_entries, dict):
            total_entries = label_entries['count']
        else:
            total_entries = label_entries
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for key, value in reader.entries():
        if count >= page * size:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if label is None or datum.label == label:
                if datum.encoded:
                    s = BytesIO()
                    s.write(datum.data)
                    s.seek(0)
                    img = PIL.Image.open(s)
                else:
                    arr = datum_to_array(datum)
                    # CHW -> HWC
                    arr = arr.transpose((1, 2, 0))
                    if arr.shape[2] == 1:
                        # HWC -> HW
                        arr = arr[:, :, 0]
                    elif arr.shape[2] == 3:
                        # BGR -> RGB
                        # XXX see issue #59
                        arr = arr[:, :, [2, 1, 0]]
                    img = PIL.Image.fromarray(arr)
                imgs.append({
                    "label": labels[datum.label],
                    "b64": utils.image.embed_image_html(img)
                })
        if label is None:
            count += 1
        else:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if datum.label == int(label):
                count += 1
        if len(imgs) >= size:
            break

    return flask.render_template(
        'datasets/images/explore.html',
        page=page, size=size, job=job, imgs=imgs, labels=labels,
        pages=pages, label=label, total_entries=total_entries, db=db)
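
# Hedged sketch of the pagination window computed above, as a standalone
# helper (the function name and `radius` parameter are illustrative, not
# part of the original view):
def _page_window(page, size, total_entries, radius=5):
    # clamp a +/- `radius` window of page links to [0, last_page]
    last_page = (total_entries - 1) // size
    return range(max(0, page - radius), min(last_page, page + radius) + 1)

# e.g. _page_window(page=2, size=25, total_entries=130) -> range(0, 6)
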
def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
    """
    Creates LMDBs for generic inference
    Returns the filename for a test image

    Creates these files in "folder":
        train_images/
        train_labels/
        val_images/
        val_labels/
        train_mean.png, train_mean.binaryproto
        val_mean.png, val_mean.binaryproto
        test.png
    """
    if image_width is None:
        image_width = IMAGE_SIZE
    if image_height is None:
        image_height = IMAGE_SIZE

    if image_count is None:
        train_image_count = TRAIN_IMAGE_COUNT
    else:
        train_image_count = image_count
    val_image_count = VAL_IMAGE_COUNT

    # Used to calculate the gradients later
    yy, xx = np.mgrid[:image_height, :image_width].astype('float')

    for phase, image_count in [('train', train_image_count),
                               ('val', val_image_count)]:
        image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
                             map_async=True, max_dbs=0)
        label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
                             map_async=True, max_dbs=0)

        image_sum = np.zeros((image_height, image_width), 'float64')

        for i in range(image_count):
            xslope, yslope = np.random.random_sample(2) - 0.5
            a = xslope * 255 / image_width
            b = yslope * 255 / image_height
            image = a * (xx - image_width / 2) + b * (yy - image_height / 2) + 127.5
            image_sum += image
            image = image.astype('uint8')

            pil_img = PIL.Image.fromarray(image)

            # create image Datum
            image_datum = dataset_pb2.Datum()
            image_datum.height = image.shape[0]
            image_datum.width = image.shape[1]
            image_datum.channels = 1
            s = BytesIO()
            pil_img.save(s, format='PNG')
            image_datum.data = s.getvalue()
            image_datum.encoded = True
            _write_to_lmdb(image_db, str(i), image_datum.SerializeToString())

            # create label Datum
            label_datum = dataset_pb2.Datum()
            label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
            label_datum.float_data.extend(np.array([xslope, yslope]).flat)
            _write_to_lmdb(label_db, str(i), label_datum.SerializeToString())

        # close databases
        image_db.close()
        label_db.close()

        # save mean
        mean_image = (image_sum / image_count).astype('uint8')
        _save_mean(mean_image, os.path.join(folder, '%s_mean.png' % phase))
        _save_mean(mean_image, os.path.join(folder, '%s_mean.binaryproto' % phase))

    # create test image
    # The network should be able to easily produce two numbers >1
    xslope, yslope = 0.5, 0.5
    a = xslope * 255 / image_width
    b = yslope * 255 / image_height
    test_image = a * (xx - image_width / 2) + b * (yy - image_height / 2) + 127.5
    test_image = test_image.astype('uint8')
    pil_img = PIL.Image.fromarray(test_image)
    test_image_filename = os.path.join(folder, 'test.png')
    pil_img.save(test_image_filename)

    return test_image_filename
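
# Hedged usage sketch: build tiny LMDBs and read entry 0 back. The b'0'
# key encoding is an assumption about the _write_to_lmdb helper (not
# shown here); `tmp_folder` is an illustrative path:
def _demo_create_lmdbs(tmp_folder):
    test_png = create_lmdbs(tmp_folder, image_width=10, image_height=10,
                            image_count=4)
    with lmdb.open(os.path.join(tmp_folder, 'train_images'),
                   readonly=True) as env:
        with env.begin() as txn:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(txn.get(b'0'))
            assert datum.encoded and datum.channels == 1
    return test_png
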
def analyze_db(database,
               only_count=False,
               force_same_shape=False,
               print_data=False,
               ):
    """
    Looks at the data in a prebuilt database and verifies it
    Also prints out some information about it
    Returns True if all entries are valid

    Arguments:
    database -- path to the database

    Keyword arguments:
    only_count -- only count the entries, don't inspect them
    force_same_shape -- throw an error if not all images have the same shape
    print_data -- print the array for each datum
    """
    start_time = time.time()

    # Open database
    try:
        database = validate_database_path(database)
    except ValueError as e:
        logger.error(e)
        return False

    reader = DbReader(database)
    logger.info('Total entries: %s' % reader.total_entries)

    unique_shapes = Counter()

    count = 0
    update_time = None
    for key, value in reader.entries():
        datum = dataset_pb2.Datum()
        datum.ParseFromString(value)

        if print_data:
            array = datum_to_array(datum)
            print('>>> Datum #%d (shape=%s)' % (count, array.shape))
            print(array)

        if (not datum.HasField('height') or datum.height == 0 or
                not datum.HasField('width') or datum.width == 0):
            if datum.encoded:
                if force_same_shape or not len(unique_shapes.keys()):
                    # Decode datum to learn the shape
                    s = BytesIO()
                    s.write(datum.data)
                    s.seek(0)
                    img = PIL.Image.open(s)
                    width, height = img.size
                    channels = len(img.split())
                else:
                    # We've already decoded one image, don't bother reading the rest
                    width = '?'
                    height = '?'
                    channels = '?'
            else:
                errstr = 'Shape is not set and datum is not encoded'
                logger.error(errstr)
                raise ValueError(errstr)
        else:
            width, height, channels = datum.width, datum.height, datum.channels

        shape = '%sx%sx%s' % (width, height, channels)
        unique_shapes[shape] += 1

        if force_same_shape and len(unique_shapes.keys()) > 1:
            logger.error("Images with different shapes found: %s and %s" %
                         tuple(unique_shapes.keys()))
            return False

        count += 1
        # Send update every 2 seconds
        if update_time is None or (time.time() - update_time) > 2:
            logger.debug('>>> Key %s' % key)
            print_datum(datum)
            logger.debug('Progress: %s/%s' % (count, reader.total_entries))
            update_time = time.time()

        if only_count:
            # quit after reading one entry
            count = reader.total_entries
            logger.info('Assuming all entries have the same shape ...')
            unique_shapes[list(unique_shapes.keys())[0]] = count
            break

    if count != reader.total_entries:
        logger.warning('LMDB reported %s total entries, but only %s were read'
                       % (reader.total_entries, count))

    for key, val in sorted(unique_shapes.items(),
                           key=operator.itemgetter(1), reverse=True):
        logger.info('%s entries found with shape %s (WxHxC)' % (val, key))

    logger.info('Completed in %s seconds.' % (time.time() - start_time,))
    return True
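
# Hedged usage sketch (the path and function name are illustrative):
def _demo_analyze_db(db_path):
    # fast pass: trust the first entry's shape and just count
    ok = analyze_db(db_path, only_count=True)
    # strict pass: decode everything and fail on mixed shapes
    strict_ok = analyze_db(db_path, force_same_shape=True)
    return ok and strict_ok
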
def infer(input_list,
          output_dir,
          jobs_dir,
          model_id,
          epoch,
          batch_size,
          layers,
          gpu,
          input_is_db,
          resize):
    """
    Perform inference on a list of images using the specified model
    """
    # job directory defaults to that defined in DIGITS config
    if jobs_dir == 'none':
        jobs_dir = digits.config.config_value('jobs_dir')

    # load model job
    model_dir = os.path.join(jobs_dir, model_id)
    assert os.path.isdir(model_dir), "Model dir %s does not exist" % model_dir
    model = Job.load(model_dir)

    # load dataset job
    dataset_dir = os.path.join(jobs_dir, model.dataset_id)
    assert os.path.isdir(dataset_dir), "Dataset dir %s does not exist" % dataset_dir
    dataset = Job.load(dataset_dir)
    for task in model.tasks:
        task.dataset = dataset

    # retrieve snapshot file
    task = model.train_task()
    snapshot_filename = None
    epoch = float(epoch)
    if epoch == -1 and len(task.snapshots):
        # use last epoch
        epoch = task.snapshots[-1][1]
        snapshot_filename = task.snapshots[-1][0]
    else:
        for f, e in task.snapshots:
            if e == epoch:
                snapshot_filename = f
                break
    if not snapshot_filename:
        raise InferenceError("Unable to find snapshot for epoch=%s" % repr(epoch))

    # retrieve image dimensions and resize mode
    image_dims = dataset.get_feature_dims()
    height = image_dims[0]
    width = image_dims[1]
    channels = image_dims[2]
    resize_mode = dataset.resize_mode if hasattr(dataset, 'resize_mode') else 'squash'

    n_input_samples = 0  # number of samples we were able to load
    input_ids = []       # indices of samples within file list
    input_data = []      # sample data

    if input_is_db:
        # load images from database
        reader = DbReader(input_list)
        for key, value in reader.entries():
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if datum.encoded:
                s = BytesIO()
                s.write(datum.data)
                s.seek(0)
                img = PIL.Image.open(s)
                img = np.array(img)
            else:
                arr = datum_to_array(datum)
                # CHW -> HWC
                arr = arr.transpose((1, 2, 0))
                if arr.shape[2] == 1:
                    # HWC -> HW
                    arr = arr[:, :, 0]
                elif arr.shape[2] == 3:
                    # BGR -> RGB
                    # XXX see issue #59
                    arr = arr[:, :, [2, 1, 0]]
                img = arr
            input_ids.append(key)
            input_data.append(img)
            n_input_samples += 1
    else:
        # load paths from file
        paths = None
        with open(input_list) as infile:
            paths = infile.readlines()
        # load and resize images
        for idx, path in enumerate(paths):
            path = path.strip()
            try:
                image = utils.image.load_image(path)
                if resize:
                    image = utils.image.resize_image(
                        image, height, width,
                        channels=channels,
                        resize_mode=resize_mode)
                else:
                    image = utils.image.image_to_array(image, channels=channels)
                input_ids.append(idx)
                input_data.append(image)
                n_input_samples += 1
            except utils.errors.LoadImageError as e:
                print(e)

    # perform inference
    visualizations = None

    if n_input_samples == 0:
        raise InferenceError("Unable to load any image from file '%s'" % repr(input_list))
    elif n_input_samples == 1:
        # single image inference
        outputs, visualizations = model.train_task().infer_one(
            input_data[0],
            snapshot_epoch=epoch,
            layers=layers,
            gpu=gpu,
            resize=resize)
    else:
        if layers != 'none':
            raise InferenceError("Layer visualization is not supported for multiple inference")
        outputs = model.train_task().infer_many(
            input_data,
            snapshot_epoch=epoch,
            gpu=gpu,
            resize=resize)

    # write to hdf5 file
    db_path = os.path.join(output_dir, 'inference.hdf5')
    db = h5py.File(db_path, 'w')

    # write input paths and images to database
    db.create_dataset("input_ids", data=input_ids)
    db.create_dataset("input_data", data=input_data)

    # write outputs to database
    db_outputs = db.create_group("outputs")
    for output_id, output_name in enumerate(outputs.keys()):
        output_data = outputs[output_name]
        output_key = base64.urlsafe_b64encode(output_name.encode()).decode()
        dset = db_outputs.create_dataset(output_key, data=output_data)
        # add ID attribute so outputs can be sorted in
        # the order they appear in here
        dset.attrs['id'] = output_id

    # write visualization data
    if visualizations is not None and len(visualizations) > 0:
        db_layers = db.create_group("layers")
        for idx, layer in enumerate(visualizations):
            vis = layer['vis'] if layer['vis'] is not None else np.empty(0)
            dset = db_layers.create_dataset(str(idx), data=vis)
            dset.attrs['name'] = layer['name']
            dset.attrs['vis_type'] = layer['vis_type']
            if 'param_count' in layer:
                dset.attrs['param_count'] = layer['param_count']
            if 'layer_type' in layer:
                dset.attrs['layer_type'] = layer['layer_type']
            dset.attrs['shape'] = layer['data_stats']['shape']
            dset.attrs['mean'] = layer['data_stats']['mean']
            dset.attrs['stddev'] = layer['data_stats']['stddev']
            dset.attrs['histogram_y'] = layer['data_stats']['histogram'][0]
            dset.attrs['histogram_x'] = layer['data_stats']['histogram'][1]
            dset.attrs['histogram_ticks'] = layer['data_stats']['histogram'][2]
    db.close()
    logger.info('Saved data to %s', db_path)
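
# Hedged sketch for reading back the HDF5 file written above; output
# dataset names are urlsafe-base64 of the blob names, so decode them for
# display (assumes the same h5py/base64 imports as the code above):
def _demo_read_inference(output_dir):
    with h5py.File(os.path.join(output_dir, 'inference.hdf5'), 'r') as db:
        outputs = db['outputs']
        # restore the original blob order via the 'id' attribute
        for key in sorted(outputs.keys(), key=lambda k: outputs[k].attrs['id']):
            name = base64.urlsafe_b64decode(key.encode()).decode()
            print(name, outputs[key][...].shape)
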
def explore():
    """
    Returns a gallery consisting of the images of one of the dbs
    """
    job = job_from_request()

    # Get LMDB
    db = flask.request.args.get('db')
    db_path = job.path(db)

    if (os.path.basename(db_path) == 'labels' and
            COLOR_PALETTE_ATTRIBUTE in job.extension_userdata and
            job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]):
        # assume single-channel 8-bit palette
        palette = job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]
        palette = np.array(palette).reshape((len(palette) // 3, 3)) / 255.
        # normalize input pixels to [0,1]
        norm = mpl.colors.Normalize(vmin=0, vmax=255)
        # create map
        cmap = plt.cm.ScalarMappable(norm=norm,
                                     cmap=mpl.colors.ListedColormap(palette))
    else:
        cmap = None

    page = int(flask.request.args.get('page', 0))
    size = int(flask.request.args.get('size', 25))

    reader = DbReader(db_path)
    count = 0
    imgs = []
    min_page = max(0, page - 5)
    total_entries = reader.total_entries
    max_page = min((total_entries - 1) // size, page + 5)
    pages = range(min_page, max_page + 1)
    for key, value in reader.entries():
        if count >= page * size:
            datum = dataset_pb2.Datum()
            datum.ParseFromString(value)
            if not datum.encoded:
                raise RuntimeError("Expected encoded database")
            s = BytesIO()
            s.write(datum.data)
            s.seek(0)
            img = PIL.Image.open(s)
            if cmap and img.mode in ['L', '1']:
                data = np.array(img)
                data = cmap.to_rgba(data) * 255
                data = data.astype('uint8')
                # keep RGB values only, remove alpha channel
                data = data[:, :, 0:3]
                img = PIL.Image.fromarray(data)
            imgs.append({
                "label": None,
                "b64": utils.image.embed_image_html(img)
            })
        count += 1
        if len(imgs) >= size:
            break

    return flask.render_template(
        'datasets/images/explore.html',
        page=page, size=size, job=job, imgs=imgs, labels=None,
        pages=pages, label=None, total_entries=total_entries, db=db)
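
# Hedged standalone sketch of the palette colorization above, usable
# without a Flask request context (mpl, plt, np, PIL as imported here;
# the function name is illustrative):
def _demo_colorize(label_img, palette):
    # palette is a flat [r0, g0, b0, r1, g1, b1, ...] list of 8-bit values
    colors = np.array(palette).reshape((len(palette) // 3, 3)) / 255.
    norm = mpl.colors.Normalize(vmin=0, vmax=255)
    cmap = plt.cm.ScalarMappable(norm=norm,
                                 cmap=mpl.colors.ListedColormap(colors))
    rgba = (cmap.to_rgba(np.array(label_img)) * 255).astype('uint8')
    return PIL.Image.fromarray(rgba[:, :, :3])  # keep RGB, drop alpha
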