示例#1
0
def do_search(image_encoder, table_name, img_path, top_k):
    """Detect objects in ``img_path`` and search Milvus for similar vectors.

    The detector is run over ``img_path``; the crops it leaves under
    ``img_path + '/object'`` are encoded with ``image_encoder`` and searched
    against ``table_name`` with ``top_k``. Hit ids are then resolved through
    ``query_name_from_ids``.

    Args:
        image_encoder: Encoder passed through to ``get_object_vector``.
        table_name: Table searched by ``search_vectors``.
        img_path: Directory holding the query image(s).
        top_k: Neighbour count forwarded to ``search_vectors``.

    Returns:
        ``(res_id, dis)``: the values produced by ``query_name_from_ids`` for
        every hit, and the matching distances, flattened in hit order.
        On any error the string ``"Fail with error <e>"`` is returned instead
        (kept for backward compatibility with existing callers).
    """
    object_dir = img_path + '/object'
    try:
        detector = Detector()
        run(detector, img_path)
        vect, _obj_images = get_object_vector(image_encoder, object_dir)
        index_client = milvus_client()
        _status, results = search_vectors(index_client, table_name, vect, top_k)
        vids = []
        dis = []
        for result in results:
            for hit in result:
                vids.append(hit.id)
                dis.append(hit.distance)
        res_id = list(query_name_from_ids(vids))
        shutil.rmtree(object_dir)
        return res_id, dis
    except Exception as e:
        # logging.exception records the traceback, which logging.error(e) lost.
        logging.exception(e)
        # The success path deletes the crop directory; do the same on failure
        # so crops from a failed query are not left on disk.
        shutil.rmtree(object_dir, ignore_errors=True)
        return "Fail with error {}".format(e)
示例#2
0
def do_search_logo(image_encoder, index_client, conn, cursor, table_name, filename):
    """Search ``table_name`` for the objects detected in frames of ``filename``.

    Frames are extracted with ``extract_frame`` under a per-call prefix.
    The detector is run over ``DATA_PATH/<prefix>``, and the crops under
    its ``object`` subdirectory are encoded, searched, and resolved by
    ``get_object_info``.

    Args:
        image_encoder: Encoder passed through to ``get_object_vector``.
        index_client: Client forwarded to ``search_vectors``.
        conn: Connection forwarded to ``get_object_info``.
        cursor: Cursor forwarded to ``get_object_info``.
        table_name: Table to search; ``LOGO_TABLE`` is used when falsy.
        filename: Input path (presumably a video — confirm). Its third
            '/'-separated component is used to build the working prefix.

    Returns:
        The ``(info, times)`` pair produced by ``get_object_info``.
    """
    detector = Detector()
    table_name = table_name or LOGO_TABLE

    # Third path component up to its first '.', plus a random hex suffix so
    # concurrent calls get distinct working directories.
    stem = filename.split("/")[2].partition(".")[0]
    prefix = "{}-{}".format(stem, uuid.uuid4().hex)
    work_dir = DATA_PATH + '/' + prefix

    # Return value is not used; presumably the call writes frames into work_dir.
    extract_frame(filename, 1, prefix)
    run(detector, work_dir)

    vectors, obj_images = get_object_vector(image_encoder, work_dir + '/object')
    # NOTE(review): "L2" is passed positionally as the fourth argument —
    # confirm it matches search_vectors' parameter (metric vs. top_k).
    results = search_vectors(index_client, table_name, vectors, "L2")

    # NOTE(review): work_dir is never removed here — confirm a caller cleans it up.
    info, times = get_object_info(conn, cursor, table_name, results, obj_images)
    return info, times
示例#3
0
def do_search(table_name, img_path, model, milvus_client, mysql_cli):
    """Detect objects in ``img_path`` and return the closest stored matches.

    Args:
        table_name: Table to search; ``DEFAULT_TABLE`` is used when falsy.
        img_path: Directory holding the query image(s). The detector reads it,
            crops are read from ``img_path + '/object'``, and the whole
            directory is deleted after a successful search.
        model: Encoder passed to ``get_object_vector``.
        milvus_client: Object providing ``search_vectors(table, vecs, top_k)``.
        mysql_cli: Object providing ``search_by_milvus_ids(ids, table)``.

    Returns:
        ``(paths, distances)``: the result of
        ``mysql_cli.search_by_milvus_ids`` for all hit ids, and the hit
        distances flattened in the same order.

    Raises:
        Any exception raised while searching, after it has been logged.
    """
    try:
        if not table_name:
            table_name = DEFAULT_TABLE
        detector = Detector()
        run(detector, img_path)
        vecs = get_object_vector(model, img_path + '/object')
        results = milvus_client.search_vectors(table_name, vecs, TOP_K)
        ids = []
        distances = []
        # Single pass keeps ids and distances aligned even if results is a
        # one-shot iterable.
        for result in results:
            for hit in result:
                ids.append(hit.id)
                distances.append(hit.distance)
        paths = mysql_cli.search_by_milvus_ids(ids, table_name)
        shutil.rmtree(img_path)
        return paths, distances
    except Exception as e:
        LOGGER.error(" Error with search : {}".format(e))
        # Previously sys.exit(1). SystemExit is not an Exception subclass, so
        # it bypasses callers' `except Exception` handlers and terminates the
        # whole process on a single failed search. Re-raise instead so the
        # caller decides how to report the failure.
        raise
示例#4
0
def do_load(table_name, database_path, model, mil_cli, mysql_cli):
    """Detect, encode and index every object found under ``database_path``.

    The detector is run over ``database_path``, and the crops under
    ``database_path + "/object"`` are encoded and inserted into Milvus table
    ``table_name``. The index is then built and the crop directory removed.
    Finally the images listed by ``get_imgs_path`` are paired with the new
    ids and stored in MySQL via ``format_data``.

    Args:
        table_name: Target table; ``DEFAULT_TABLE`` is used when falsy.
        database_path: Directory of source images.
        model: Encoder passed to ``get_object_vector``.
        mil_cli: Object providing ``insert`` and ``create_index``.
        mysql_cli: Object providing ``create_mysql_table`` and
            ``load_data_to_mysql``.

    Returns:
        The number of ids returned by ``mil_cli.insert``.
    """
    detector = Detector()
    if not table_name:
        table_name = DEFAULT_TABLE
    cache = Cache(CACHE_DIR)
    object_dir = database_path + "/object"
    try:
        _result_images, object_num = run(detector, database_path)
        vectors = get_object_vector(cache, model, object_dir)
        ids = mil_cli.insert(table_name, vectors)
        mil_cli.create_index(table_name)
    except Exception:
        # Don't leave this run's crops behind on failure; a later load would
        # otherwise presumably read them again. Re-raise so callers still see
        # the error.
        shutil.rmtree(object_dir, ignore_errors=True)
        raise
    # Must happen before get_imgs_path so crops are not listed as images.
    shutil.rmtree(object_dir)
    imgs = get_imgs_path(database_path)
    imgs.sort()
    matched_imgs = match_ids_and_imgs(imgs, object_num)
    mysql_cli.create_mysql_table(table_name)
    mysql_cli.load_data_to_mysql(table_name, format_data(ids, matched_imgs))
    return len(ids)
示例#5
0
def do_train(table_name, database_path):
    """Index every object found in the images under ``database_path``.

    Steps:

    1. Run the detector over ``database_path``.
    2. Encode the crops under ``database_path + "/object"``.
    3. Create Milvus table ``table_name`` if ``has_table`` reports it
       missing; ``DEFAULT_TABLE`` is used when ``table_name`` is falsy.
    4. Insert the vectors and build the index.
    5. Record in ``cache`` which source image each inserted id belongs to.

    Returns:
        ``None`` on success (after printing "train finished"), or the string
        ``"Error with <e>"`` if anything inside the ``try`` fails.
    """
    detector = Detector()
    if not table_name:
        table_name = DEFAULT_TABLE
    cache = Cache(default_cache_dir)
    object_dir = database_path + "/object"
    try:
        _result_images, object_num = run(detector, database_path)
        # NOTE(review): image_encoder is not a parameter; it must be a
        # module-level global defined elsewhere — confirm.
        vectors, _obj_images = get_object_vector(cache, image_encoder, object_dir)
        index_client = milvus_client()
        _status, ok = has_table(index_client, table_name)
        if not ok:
            print("create table.")
            create_table(index_client, table_name=table_name)
        print("insert into:", table_name)
        _status, ids = insert_vectors(index_client, table_name, vectors)
        create_index(index_client, table_name)
        # Remove crops before listing so only the source images remain.
        shutil.rmtree(object_dir)
        imgs = sorted(os.listdir(database_path))

        ids = list(ids)
        expected = sum(object_num)
        # Fail before writing anything to the cache, rather than partway
        # through with an opaque "pop from empty list".
        if len(ids) < expected:
            raise ValueError(
                "got {} ids for {} detected objects".format(len(ids), expected))
        # The pairing assumes object_num[k] counts the objects detected in
        # imgs[k] (sorted listdir order) — TODO confirm run() uses that order.
        # ids are consumed in insertion order, num at a time.
        start = 0
        for k, num in enumerate(object_num):
            for milvus_id in ids[start:start + num]:
                cache[milvus_id] = imgs[k]
            start += num
        print("train finished")
        return None
    except Exception as e:
        # logging.exception keeps the traceback that logging.error(e) lost.
        logging.exception(e)
        # The success path removes the crop directory; do the same on failure.
        shutil.rmtree(object_dir, ignore_errors=True)
        return "Error with {}".format(e)