Ejemplo n.º 1
0
 def _get_feat(self, db, f_class):
     """Build samples for ``db`` with the feature extractor named ``f_class``.

     Args:
         db: database handed to the extractor's ``make_samples``.
         f_class: one of 'color', 'daisy', 'edge', 'gabor', 'hog', 'vgg', 'res'.

     Returns:
         Whatever ``make_samples(db, verbose=False)`` returns for that extractor.

     Raises:
         ValueError: if ``f_class`` is not one of the names above.
     """
     # Extractors are only constructed for the selected branch, so unused
     # (possibly heavy) models are never instantiated.
     if f_class == 'color':
         f_c = Color()
     elif f_class == 'daisy':
         f_c = Daisy()
     elif f_class == 'edge':
         f_c = Edge()
     elif f_class == 'gabor':
         f_c = Gabor()
     elif f_class == 'hog':
         f_c = HOG()
     elif f_class == 'vgg':
         f_c = VGGNetFeat()
     elif f_class == 'res':
         f_c = ResNetFeat()
     else:
         raise ValueError('unknown feature class: {!r}'.format(f_class))
     return f_c.make_samples(db, verbose=False)
    def __init__(self, db, f_class=None, d_type='L1'):
        """Open the NGT index for one feature type, building it first if missing.

        Args:
            db: database handed to the extractor's ``make_samples`` when the
                index has to be built.
            f_class: feature name; one of 'daisy', 'edge', 'hog', 'vgg', 'res'.
            d_type: NGT distance type; also part of the index directory name.

        Raises:
            ValueError: if ``f_class`` is not one of the supported names.
        """
        supported = ('daisy', 'edge', 'hog', 'vgg', 'res')
        # Fail fast with a clear message instead of a later AttributeError on
        # self.f_c or an NGT error on an empty path.
        if f_class not in supported:
            raise ValueError(
                'unsupported feature class {!r}; expected one of {}'.format(
                    f_class, ', '.join(supported)))

        self.NGT_dir = 'NGT_{}_{}'.format(f_class, d_type)
        self.NGT_path = b'NGT/NGT_' + f_class.encode() + b'_' + d_type.encode()
        # Attribute name kept as spelled; other code may read it.
        self.fearure = f_class
        self.SQLdb = SQLite()

        if f_class == 'daisy':
            self.f_c = Daisy()
        elif f_class == 'edge':
            self.f_c = Edge()
        elif f_class == 'hog':
            self.f_c = HOG()
        elif f_class == 'vgg':
            self.f_c = VGGNetFeat()
        else:  # 'res'
            self.f_c = ResNetFeat()

        # NGT_dir here is the module-level base directory, not self.NGT_dir.
        if not os.path.exists(os.path.join(NGT_dir, self.NGT_dir)):
            samples = self.f_c.make_samples(db, verbose=False)
            try:
                dim = samples[0]['hist'].shape[0]
            except (IndexError, KeyError, TypeError, AttributeError):
                # No usable first sample to read the vector length from.
                dim = 0

            images = []
            objects = []
            for i, row in enumerate(samples):
                # 'lable' is the stored key name; keep it unchanged.
                images.append({'index': i, 'link': row['img'], 'lable': row['cls']})
                objects.append(row['hist'])
            self.SQLdb.updateMuti(f_class, images)

            ngtpy.create(path=self.NGT_path, dimension=dim, distance_type=d_type)
            builder = ngtpy.Index(self.NGT_path)
            try:
                builder.batch_insert(objects)
                builder.save()
            finally:
                # Release the build handle; the saved index is reopened below.
                builder.close()

        self.index = ngtpy.Index(self.NGT_path)
Ejemplo n.º 3
0
    # print(result)

    # # retrieve by gabor
    # method = Gabor()
    # samples = method.make_samples(db)
    # query = samples[query_idx]
    # _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    # print(result)

    # # retrieve by HOG
    # method = HOG()
    # samples = method.make_samples(db)
    # query = samples[query_idx]
    # _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    # print(result)

    # # retrieve by VGG
    # method = VGGNetFeat()
    # samples = method.make_samples(db)
    # query = samples[query_idx]
    # _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    # print(result)

    # retrieve by resnet
    method = ResNetFeat()
    samples = method.make_samples(db)
    query_idx = 1569
    query = samples[query_idx]
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)
Ejemplo n.º 4
0
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)

    # retrieve by gabor
    method = Gabor()
    samples = method.make_samples(db)
    query = samples[query_idx]
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)

    # retrieve by HOG
    method = HOG()
    samples = method.make_samples(db)
    query = samples[query_idx]
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)

    # retrieve by VGG
    method = VGGNetFeat()
    samples = method.make_samples(db)
    query = samples[query_idx]
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)

    # retrieve by resnet
    method = ResNetFeat()
    samples = method.make_samples(db)
    query = samples[query_idx]
    _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
    print(result)
Ejemplo n.º 5
0
# -*- coding: utf-8 -*-

from __future__ import print_function

from evaluate import infer
from DB import Database
from resnet import ResNetFeat

# Retrieval settings used by the demo below.
depth = 5       # passed to infer as ``depth``
d_type = 'd1'   # passed to infer as ``d_type``
query_idx = 0   # not used by this script


def _resnet_lookup(database, image_path):
    """Return ResNet retrieval results for the image file at ``image_path``.

    The query feature comes from ``getFeatureQuery`` and is matched against
    every sample produced from ``database``.
    """
    extractor = ResNetFeat()
    query_feature = extractor.getFeatureQuery(image_path)
    candidates = extractor.make_samples(database)
    _, matches = infer(query_feature, samples=candidates, depth=depth, d_type=d_type)
    return matches


if __name__ == '__main__':
    print(_resnet_lookup(Database(), 'queries/image_06736.jpg'))
Ejemplo n.º 6
0
def test(db, query_idx):
    """Classify one database image by majority vote and display matches.

    Runs retrieval with the Color, Edge, HOG, VGG and ResNet extractors. The
    class of each method's top hit counts as one vote. The function prints the
    tally, then shows and prints up to nine images from the winning class
    folder under ``./database/``. Finally it prints the result and shows the
    query image.

    Args:
        db: database passed to every extractor's ``make_samples``.
        query_idx: index of the query image within the generated samples.
    """
    import os
    from itertools import islice

    from PIL import Image

    results = {}
    query_img_path = None
    # Daisy and Gabor retrieval are disabled and do not vote.
    for extractor_cls in (Color, Edge, HOG, VGGNetFeat, ResNetFeat):
        method = extractor_cls()
        samples = method.make_samples(db)
        query = samples[query_idx]
        if extractor_cls is Edge:
            # This is the query image displayed at the very end.
            query_img_path = query['img']
        _, result = infer(query, samples=samples, depth=depth, d_type=d_type)
        inc(results, result[0]['cls'])

    print(results)
    # Class with the most votes; on a tie, the class voted for first wins.
    finalresult = max(results, key=results.get)
    string = "./database/" + finalresult + "/"
    print(string)
    # Show at most nine images from the winning class directory.
    for file_name in islice(os.listdir(string), 9):
        with Image.open(os.path.join(string, file_name)) as tempimg:
            tempimg.show()
        print(string + file_name)
    print("Final result is: ", finalresult)

    # scipy.misc.imread/imshow were removed from SciPy; PIL is used instead.
    with Image.open(query_img_path) as query_img:
        query_img.show()
class NGT(object):
    """Nearest-neighbour image search backed by an NGT index.

    Feature vectors come from one feature extractor (``self.f_c``).
    Per-image metadata (index, link, label) is written through ``SQLite`` and
    looked up by NGT object id in ``search``.
    """

    # Feature names this class can index.
    SUPPORTED_FEATURES = ('daisy', 'edge', 'hog', 'vgg', 'res')

    def __init__(self, db, f_class=None, d_type='L1'):
        """Open the NGT index for one feature type, building it first if missing.

        Args:
            db: database handed to the extractor's ``make_samples`` when the
                index has to be built.
            f_class: feature name; one of ``SUPPORTED_FEATURES``.
            d_type: NGT distance type; also part of the index directory name.

        Raises:
            ValueError: if ``f_class`` is not in ``SUPPORTED_FEATURES``.
        """
        # Fail fast with a clear message instead of a later AttributeError on
        # self.f_c or an NGT error on an empty path.
        if f_class not in self.SUPPORTED_FEATURES:
            raise ValueError(
                'unsupported feature class {!r}; expected one of {}'.format(
                    f_class, ', '.join(self.SUPPORTED_FEATURES)))

        self.NGT_dir = 'NGT_{}_{}'.format(f_class, d_type)
        self.NGT_path = b'NGT/NGT_' + f_class.encode() + b'_' + d_type.encode()
        # Attribute name kept as spelled: search() reads it, and other code may too.
        self.fearure = f_class
        self.SQLdb = SQLite()
        self.f_c = self._make_extractor(f_class)

        # NGT_dir here is the module-level base directory, not self.NGT_dir.
        if not os.path.exists(os.path.join(NGT_dir, self.NGT_dir)):
            self._build_index(db, f_class, d_type)

        self.index = ngtpy.Index(self.NGT_path)

    @staticmethod
    def _make_extractor(f_class):
        """Return a new feature extractor for a supported ``f_class``."""
        if f_class == 'daisy':
            return Daisy()
        if f_class == 'edge':
            return Edge()
        if f_class == 'hog':
            return HOG()
        if f_class == 'vgg':
            return VGGNetFeat()
        if f_class == 'res':
            return ResNetFeat()
        raise ValueError('unsupported feature class: {!r}'.format(f_class))

    def _build_index(self, db, f_class, d_type):
        """Extract features from ``db``, store metadata and write a new index."""
        samples = self.f_c.make_samples(db, verbose=False)
        try:
            dim = samples[0]['hist'].shape[0]
        except (IndexError, KeyError, TypeError, AttributeError):
            # No usable first sample to read the vector length from.
            dim = 0

        images = []
        objects = []
        for i, row in enumerate(samples):
            # 'lable' is the stored key name; keep it unchanged.
            images.append({'index': i, 'link': row['img'], 'lable': row['cls']})
            objects.append(row['hist'])
        self.SQLdb.updateMuti(f_class, images)

        ngtpy.create(path=self.NGT_path, dimension=dim, distance_type=d_type)
        builder = ngtpy.Index(self.NGT_path)
        try:
            builder.batch_insert(objects)
            builder.save()
        finally:
            # Release the build handle; __init__ reopens the saved index.
            builder.close()

    def search(self, link, depth=5):
        """Return stored metadata for the ``depth`` nearest neighbours of ``link``.

        The query vector comes from ``self.f_c.get_featInput(link)``. The first
        element of each NGT hit (the object id) is resolved through
        ``self.SQLdb.select``.
        """
        query = self.f_c.get_featInput(link)
        hits = self.index.search(query, depth)
        return [self.SQLdb.select(self.fearure, hit[0]) for hit in hits]

    def add(self, objects):
        """Insert ``objects`` into the index, rebuild and save it; return the ids.

        ``objects`` is passed unchanged to ``ngtpy.Index.insert``. The index
        stays open so this instance can keep searching; call ``close()`` when done.
        NOTE(review): metadata for new ids is not written to ``self.SQLdb``,
        so ``search`` may be unable to resolve them. Confirm with callers.
        """
        ids = self.index.insert(objects)
        self.index.build_index()
        self.index.save()
        return ids

    def remove(self, id):
        """Remove object ``id`` from the index and save it; returns 0.

        The index stays open so this instance can keep searching.
        """
        self.index.remove(id)
        self.index.save()
        return 0

    def close(self):
        """Close the underlying NGT index; ``search`` is unusable afterwards."""
        self.index.close()
Ejemplo n.º 8
0
            raise Exception
        return name, weight
    except:
        raise argparse.ArgumentTypeError(
            f"\nFeature must be 'name:weight'\n\tname in {features.keys()}\n\tweight >= 1"
        )


# Registry of feature extractors, keyed by the names accepted in
# 'name:weight' arguments. Each extractor is instantiated once, at import time.
features = dict(
    color=Color(),
    daisy=Daisy(),
    edge=Edge(),
    gabor=Gabor(),
    hog=HOG(),
    vgg=VGGNetFeat(),
    res=ResNetFeat(),
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-n",
                        "--neighbor",
                        help="neighbor by class",
                        type=int,
                        default=3)
    parser.add_argument(
        "-c",
        help="Copy images in a result path (src/CBIR/result/retrieval/)",
        action="store_true")
    egroup = parser.add_mutually_exclusive_group(required=True)
    egroup.add_argument(