def _get_feat(self, db, f_class): if f_class == 'color': f_c = Color() elif f_class == 'daisy': f_c = Daisy() elif f_class == 'edge': f_c = Edge() elif f_class == 'gabor': f_c = Gabor() elif f_class == 'hog': f_c = HOG() elif f_class == 'vgg': f_c = VGGNetFeat() elif f_class == 'res': f_c = ResNetFeat() return f_c.make_samples(db, verbose=False)
def __init__(self, db, f_class=None, d_type='L1'): self.NGT_dir = 'NGT_{}_{}'.format(f_class,d_type) self.NGT_path = b'' self.fearure = f_class self.SQLdb = SQLite() if f_class == 'daisy': self.f_c = Daisy() self.NGT_path = b'NGT/NGT_daisy_'+d_type.encode() elif f_class == 'edge': self.f_c = Edge() self.NGT_path = b'NGT/NGT_edge_'+d_type.encode() elif f_class == 'hog': self.f_c = HOG() self.NGT_path = b'NGT/NGT_hog_'+d_type.encode() elif f_class == 'vgg': self.f_c = VGGNetFeat() self.NGT_path = b'NGT/NGT_vgg_'+d_type.encode() elif f_class == 'res': self.f_c = ResNetFeat() self.NGT_path = b'NGT/NGT_res_'+d_type.encode() if not os.path.exists(os.path.join(NGT_dir,self.NGT_dir)): samples = self.f_c.make_samples(db, verbose=False) dim = 0 try: dim = samples[0]['hist'].shape[0] except: pass images= [] objects = [] for i, row in enumerate(samples): vector = row['hist'] link = row['img'] lable = row['cls'] data = {'index':i,'link':link,'lable':lable} images.append(data) objects.append(vector) self.SQLdb.updateMuti(f_class,images) # cPickle.dump(images, open(os.path.join(NGT_dir, sample_cache), "wb", True)) ngtpy.create(path=self.NGT_path, dimension=dim, distance_type=d_type) self.index = ngtpy.Index(self.NGT_path) self.index.batch_insert(objects) self.index.save() self.index = ngtpy.Index(self.NGT_path)
_, result = infer(query, samples=samples, depth=depth, d_type=d_type) print(result) # retrieve by gabor method = Gabor() samples = method.make_samples(db) query = samples[query_idx] _, result = infer(query, samples=samples, depth=depth, d_type=d_type) print(result) # retrieve by HOG method = HOG() samples = method.make_samples(db) query = samples[query_idx] _, result = infer(query, samples=samples, depth=depth, d_type=d_type) print(result) # retrieve by VGG method = VGGNetFeat() samples = method.make_samples(db) query = samples[query_idx] _, result = infer(query, samples=samples, depth=depth, d_type=d_type) print(result) # retrieve by resnet method = ResNetFeat() samples = method.make_samples(db) query = samples[query_idx] _, result = infer(query, samples=samples, depth=depth, d_type=d_type) print(result)
def test(db, query_idx, max_shown=9):
    """Vote on the class of one query image across several feature extractors.

    For each extractor (color, edge, HOG, VGG, ResNet) the samples of *db* are
    built and the sample at *query_idx* is used as the query. The class of the
    top-ranked ``infer`` result gets one vote via ``inc``. The class with the
    most votes is printed. Up to *max_shown* images from
    ``./database/<class>/`` are then displayed, followed by the query image.

    Relies on ``depth``, ``d_type``, ``infer``, ``inc`` and ``operator`` from
    outside this function.

    Args:
        db: image database passed to each extractor's ``make_samples``.
        query_idx: index of the query sample within each extractor's samples.
        max_shown: maximum number of images of the winning class to display.
            Defaults to 9, the count the previous hand-rolled counter produced.
    """
    import itertools
    import os

    # scipy.misc.imread / scipy.misc.imshow were removed in SciPy 1.2;
    # PIL is used for image display instead.
    from PIL import Image

    results = {}
    query_img_path = None
    # Daisy and Gabor extractors are intentionally excluded from the vote.
    for method_cls in (Color, Edge, HOG, VGGNetFeat, ResNetFeat):
        samples = method_cls().make_samples(db)
        query = samples[query_idx]
        if query_img_path is None:
            # Assumes every extractor orders samples identically, so the
            # query path is the same for all of them. TODO confirm.
            query_img_path = query['img']
        _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
        inc(results, result[0]['cls'])

    print(results)
    # Ties resolve to the class that reached the top count first (max keeps
    # the first maximal item in dict insertion order).
    finalresult = max(results.items(), key=operator.itemgetter(1))[0]

    class_dir = "./database/" + finalresult + "/"
    print(class_dir)
    for fname in itertools.islice(os.listdir(class_dir), max_shown):
        with Image.open(class_dir + fname) as shown:
            shown.show()
        print(class_dir + fname)
    print("Final result is: ", finalresult)

    with Image.open(query_img_path) as query_img:
        query_img.show()
if name not in features.keys() or weight < 1: raise Exception return name, weight except: raise argparse.ArgumentTypeError( f"\nFeature must be 'name:weight'\n\tname in {features.keys()}\n\tweight >= 1" ) features = { "color": Color(), "daisy": Daisy(), "edge": Edge(), "gabor": Gabor(), "hog": HOG(), "vgg": VGGNetFeat(), "res": ResNetFeat(), } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-n", "--neighbor", help="neighbor by class", type=int, default=3) parser.add_argument( "-c", help="Copy images in a result path (src/CBIR/result/retrieval/)", action="store_true") egroup = parser.add_mutually_exclusive_group(required=True)