def create_gist_database(data_dir, img_dirs, param=None):
    """Build the on-disk GIST descriptor database if it does not exist yet.

    Every regular file whose name contains ``'jpg'`` inside
    ``data_dir/<img_dir>`` (for each entry of ``img_dirs``) is passed to
    ``lmgist.lmgist``.  The descriptors are saved with ``np.save`` to the
    module-level path ``gists_db_name``, the matching image paths to
    ``f_db_name``, and the GIST parameter object returned by the last
    ``lmgist.lmgist`` call is pickled to the file ``'gistparams'`` in the
    current working directory.

    If both ``gists_db_name`` and ``f_db_name`` already exist, nothing is
    computed and a message is printed instead.

    Args:
        data_dir: Root directory that contains the scene sub-directories.
        img_dirs: Iterable of sub-directory names (relative to ``data_dir``).
        param: Optional GIST parameter object handed to ``lmgist.lmgist``.
            When omitted, a module-level ``param`` is used if one exists.

    Returns:
        None.
    """
    # Guard clause: both database files present -> nothing to do.
    if os.path.isfile(gists_db_name) and os.path.isfile(f_db_name):
        print('%s and %s already exist' % (gists_db_name, f_db_name))
        return

    if param is None:
        # The previous body read ``param`` before assigning it, which made it
        # an unbound local (UnboundLocalError on the first image).  Presumably
        # a module-level ``param`` was intended -- NOTE(review): confirm, and
        # confirm whether lmgist.lmgist accepts None when no such global
        # exists.
        param = globals().get('param')

    gist_data = []
    file_names = []
    for img_dir in img_dirs:
        scene_dir = os.path.join(data_dir, img_dir)
        for name in os.listdir(scene_dir):
            img_file = os.path.join(scene_dir, name)
            # NOTE(review): substring match kept from the original; it also
            # accepts names such as 'jpg_notes.txt' and rejects '.JPG'.
            if not (os.path.isfile(img_file) and 'jpg' in name):
                continue
            print(img_file)
            # lmgist returns the descriptor together with a parameter object,
            # which is fed back into the next call.
            gist, param = lmgist.lmgist(img_file, param)
            gist_data.append(gist)
            file_names.append(img_file)

    np.save(gists_db_name, gist_data)
    np.save(f_db_name, file_names)
    # Pickle streams are binary: text mode fails on Python 3 and can corrupt
    # the stream on Windows.
    with open('gistparams', 'wb') as param_file:
        pickle.dump(param, param_file)
def get_gist(query_name, img_query, img_mask, param):
    """Compute the GIST descriptor of a query image and its mask weights.

    The descriptor is computed by ``lmgist.lmgist(query_name, param)``.  The
    mask is resized to ``param.img_size`` x ``param.img_size``, rescaled to
    [0, 1], averaged over a ``param.number_blocks`` x ``param.number_blocks``
    grid and inverted (``1 - mean``), so blocks whose mask mean is high get
    a low weight.  The per-block weights are then repeated once per filter
    (``sum(param.orientations_per_scale)`` filters).

    Args:
        query_name: Image path/identifier forwarded to ``lmgist.lmgist``.
        img_query: Query image.  Not used by the computation; kept so the
            signature stays compatible with existing callers.
        img_mask: 2-D mask image.
        param: GIST parameter object providing ``img_size``,
            ``number_blocks`` and ``orientations_per_scale``.

    Returns:
        Tuple ``(gist, block_weight)`` where ``gist`` is whatever
        ``lmgist.lmgist`` returns and ``block_weight`` is an array of shape
        ``(number_blocks ** 2 * n_filters, 1)``.

    Raises:
        TypeError: If the resized mask is not 2-D.
    """
    # lmgist returns the descriptor plus a parameter object; that returned
    # object is the one used for all sizes below.
    gist, param = lmgist.lmgist(query_name, param)

    # Resize the mask to the resolution used for the GIST descriptor.
    # (The original also resized img_query but never used the result.)
    img_mask = transform.resize(img_mask,
                                (param.img_size, param.img_size),
                                clip=True, order=1)

    if img_mask.ndim != 2:
        raise TypeError('Mask image needs to be 2D!')

    # Bring the mask into [0, 1].  The original test was
    # ``np.max(img_mask == 1)`` (misplaced parenthesis), which only checked
    # that *some* pixel equals 1 instead of checking the maximum.
    lo = np.min(img_mask)
    hi = np.max(img_mask)
    if not (lo == 0 and hi == 1):
        if hi > lo:
            img_mask = (img_mask - lo) / float(hi - lo)
        else:
            # Constant mask: min-max scaling would divide by zero and yield
            # NaN weights, so just clamp the values into range.
            img_mask = np.clip(img_mask, 0, 1)

    # GIST weights are the average value of mask pixels within each block.
    n_blocks = param.number_blocks
    edges = np.linspace(0, param.img_size, n_blocks + 1).astype(int)
    block_weight = np.zeros((n_blocks, n_blocks))
    for row in range(n_blocks):
        for col in range(n_blocks):
            block = img_mask[edges[row]:edges[row + 1],
                             edges[col]:edges[col + 1]]
            block_weight[row, col] = np.mean(block)
    block_weight = 1 - block_weight

    # Flatten the transposed grid into a column vector and stack one copy
    # per filter.
    n_filters = sum(param.orientations_per_scale)
    column = np.reshape(block_weight.T, [n_blocks ** 2, 1])
    block_weight = np.tile(column, (n_filters, 1))
    return gist, block_weight