Beispiel #1
0
def create_table(_table_name):
    """(Re)create the table ``_table_name`` on the local Milvus server.

    If a table with that name exists it is deleted first. After the delete
    the function waits 5 seconds and raises if the table is still reported
    as present. The new table is then created from ``param``, which is not
    defined in this function. It presumably is a module-level table schema
    dict; confirm against the rest of the file. A failed create is only
    printed, not raised.

    The connection is always closed before returning or raising. On a
    normal return the function sleeps 1 more second after disconnecting.

    Raises:
        AssertionError: if deleting the existing table returns a non-OK status.
        Exception: if the table still exists after the delete and the wait.
    """
    milvus = Milvus()
    milvus.connect(host="localhost", port="19530")
    try:
        if milvus.has_table(_table_name):
            print(f"Table {_table_name} found, now going to delete it")
            status = milvus.delete_table(_table_name)
            assert status.OK(), "delete table {} failed".format(_table_name)

            # wait for table deleted; only needed when something was deleted
            time.sleep(5)

            if milvus.has_table(_table_name):
                raise Exception("Delete table error")

            # Only report a deletion that actually happened.
            print("delete table {} successfully!".format(_table_name))

        status = milvus.create_table(param)
        if not status.OK():
            print("Create table {} failed".format(_table_name))
    finally:
        # in main process, milvus must be closed before subprocess start
        milvus.disconnect()

    time.sleep(1)
def validate_insert(_table_name):
    """Assert that ``_table_name`` holds exactly 10 * 10000 vectors.

    Connects with the module-level ``server_config`` (defined outside this
    function). The connection is closed even when the check fails.

    Raises:
        AssertionError: if the count returned by ``count_table`` differs.
    """
    milvus = Milvus()
    milvus.connect(**server_config)
    try:
        _, count = milvus.count_table(_table_name)
        assert count == 10 * 10000, "Insert validate fail. Vectors num is not matched."
    finally:
        milvus.disconnect()

    # NOTE(review): `_add` is defined here but never called inside
    # validate_insert, so it has no effect. It looks like an insert worker
    # that belongs elsewhere; confirm.
    def _add():
        """Insert the vectors from ``_generate_vectors(128, 10000)`` into ``_table_name``.

        Presumably that is 10000 vectors of dimension 128; confirm against
        ``_generate_vectors``. The connection is closed even if the insert raises.
        """
        milvus = Milvus()
        milvus.connect(**server_config)
        try:
            vectors = _generate_vectors(128, 10000)
            print('\n\tPID: {}, insert {} vectors'.format(os.getpid(), 10000))
            milvus.add_vectors(_table_name, vectors)
        finally:
            milvus.disconnect()
Beispiel #4
0
def validate_insert(_table_name):
    """Assert that ``_table_name`` holds ``vector_num * process_num`` rows.

    ``vector_num`` and ``process_num`` are not defined in this function and
    must come from module scope. The connection is closed even when the
    check fails.

    Raises:
        AssertionError: if ``get_table_row_count`` reports a different count.
    """
    milvus = Milvus()
    milvus.connect(host="localhost", port="19530")
    try:
        _, count = milvus.get_table_row_count(_table_name)
        expected = vector_num * process_num
        # The original message was mangled into invalid syntax
        # ("..."******"..."); this rebuilds it as adjacent string literals.
        assert count == expected, (
            "Error: validate insert not pass: "
            f"{expected} expected but {count} instead!"
        )
    finally:
        milvus.disconnect()
Beispiel #5
0
    def _add():
        """Insert the enclosing scope's ``vectors`` into ``_table_name``.

        ``vectors`` and ``_table_name`` are free variables resolved from the
        enclosing function, which is not in view. If connecting fails, a
        message with the current PID is printed and nothing is inserted.
        """
        milvus = Milvus()
        status = milvus.connect()

        # ``OK`` is a method; it is called as ``status.OK()`` elsewhere in
        # this file. Testing the bound method itself is always truthy, so
        # ``not status.OK`` could never detect a failed connect.
        if not status.OK():
            print(f'PID: {os.getpid()}, connect failed')
            return

        try:
            milvus.add_vectors(_table_name, vectors)
        finally:
            milvus.disconnect()
Beispiel #6
0
 def test_connect_connected(self, args):
     """
     target: connect with the correct ip and port, disconnect, then check the connected flag
     method: connect using args["ip"] / args["port"], then call disconnect()
     expected: connected() is falsy after disconnect
     """
     client = Milvus()
     ip, port = args["ip"], args["port"]
     client.connect(host=ip, port=port)
     client.disconnect()
     still_connected = client.connected()
     assert not still_connected
Beispiel #7
0
    def test_connect_disconnect_repeatedly_once(self, args):
        """
        target: connect, disconnect, then connect again with the same uri
        method: connect via a tcp:// uri, disconnect, reconnect with that uri
        expected: connected() is truthy after the reconnect
        """
        client = Milvus()
        uri_value = "tcp://{}:{}".format(args["ip"], args["port"])
        client.connect(uri=uri_value)

        # one disconnect/reconnect cycle
        client.disconnect()
        client.connect(uri=uri_value)

        assert client.connected()
Beispiel #8
0
 def test_connect_disconnect_repeatedly_times(self, args):
     """
     target: cycle disconnect/connect 10 times on one client
     method: connect once, then repeat disconnect + connect with the same uri
     expected: connected() is truthy after the last cycle
     """
     client = Milvus()
     uri_value = "tcp://{}:{}".format(args["ip"], args["port"])
     client.connect(uri=uri_value)
     cycles = 10
     for _ in range(cycles):
         client.disconnect()
         client.connect(uri=uri_value)
     assert client.connected()
Beispiel #9
0
def dis_connect(request):
    """Return a Milvus client that has been connected and then disconnected.

    The server address comes from the pytest ``--ip`` / ``--port`` options
    on ``request.config``. A finalizer is registered that makes one more
    best-effort ``disconnect()`` at teardown.
    """
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    milvus = Milvus()
    milvus.connect(host=ip, port=port)
    milvus.disconnect()

    def fin():
        # Best-effort cleanup. The client is normally already disconnected,
        # and a second disconnect is expected to raise, so ordinary errors
        # are ignored. Catching Exception rather than using a bare except
        # lets KeyboardInterrupt/SystemExit still propagate.
        try:
            milvus.disconnect()
        except Exception:
            pass

    request.addfinalizer(fin)
    return milvus
Beispiel #10
0
 def test_disconnect_repeatedly(self, connect, args):
     """
     target: disconnect the same client twice in a row
     method: use the `connect` fixture if it is connected, otherwise open a
             fresh client from args; disconnect it, then disconnect again
     expected: the second disconnect raises an exception
     """
     if connect.connected():
         client = connect
     else:
         client = Milvus()
         client.connect(uri="tcp://%s:%s" % (args["ip"], args["port"]))
     client.disconnect()
     with pytest.raises(Exception):
         client.disconnect()
Beispiel #11
0
    def multi_thread_opr(table_name, utid):
        """Create a 64-dimension table and insert 10000 random vectors over HTTP.

        ``utid`` only labels the start-up log line (presumably a worker or
        thread id). The status returned by ``create_table`` is ignored. The
        client is disconnected even if create or insert raises.
        """
        print("[{}] | T{} | Running .....".format(datetime.datetime.now(),
                                                  utid))

        dimension = 64
        client0 = Milvus(handler="HTTP")

        table_param = {'table_name': table_name, 'dimension': dimension}

        vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]

        client0.connect()
        try:
            client0.create_table(table_param)
            client0.insert(table_name, vectors)
        finally:
            client0.disconnect()
Beispiel #12
0
def main():
    """Run a sample search on ``test_search`` and print its row count.

    Searches with one random 256-dimensional query vector (top 5) and
    ``query_ranges`` of 2019-06-25..2019-06-25, which is presumably a date
    filter. The hits are pretty-printed when the search status is OK, then
    the table's row count is printed. The connection is closed even if a
    step raises.
    """
    milvus = Milvus()
    milvus.connect(host=_HOST, port=_PORT)
    try:
        table_name = 'test_search'
        dimension = 256
        ranges = [['2019-06-25', '2019-06-25']]

        vectors = Prepare.records([[random.random() for _ in range(dimension)]
                                   for _ in range(1)])
        LOGGER.info(ranges)
        param = {
            'table_name': table_name,
            'query_records': vectors,
            'top_k': 5,
            'query_ranges': ranges  # Not fully tested yet
        }
        status, result = milvus.search_vectors(**param)
        if status.OK():
            pprint(result)

        _, result = milvus.get_table_row_count(table_name)
        print('# Count: {}'.format(result))
    finally:
        milvus.disconnect()

    # NOTE(review): `_create_table` is never called inside main(). Its
    # `_table_name` is not a local of main and must come from module scope;
    # confirm.
    def _create_table(_table_param):
        """Drop ``_table_name`` if it exists, then create a table from ``_table_param``.

        Raises:
            Exception: if the delete returns a non-OK status or the table is
                still present after the delete and the wait.
        """
        milvus = Milvus()
        milvus.connect(**server_config)
        try:
            status, ok = milvus.has_table(_table_name)
            if ok:
                print("Table {} found, now going to delete it".format(_table_name))
                status = milvus.delete_table(_table_name)
                if not status.OK():
                    raise Exception("Delete table error")
                print("delete table {} successfully!".format(_table_name))
                # wait for the deletion to take effect
                time.sleep(5)

            status, ok = milvus.has_table(_table_name)
            if ok:
                raise Exception("Delete table error")

            # Use the schema passed in. A bare `param` here would resolve to
            # main()'s search-parameter dict, not a table schema.
            status = milvus.create_table(_table_param)
            if not status.OK():
                print("Create table {} failed".format(_table_name))
        finally:
            milvus.disconnect()
Beispiel #14
0
def main():
    """Run the basic collection workflow against a Milvus server.

    Steps:
    - Connect to ``_HOST``:``_PORT``; print a message and exit with status 1 on failure.
    - Create ``example_collection`` if it does not exist.
    - Insert 10000 random ``_DIM``-dimensional vectors and flush.
    - Build an IVF_FLAT index.
    - Search with the first 10 inserted vectors (top 1, nprobe 16) and print the results.
    - Drop the collection and disconnect.

    The connection is closed even if a step in between raises.
    """
    milvus = Milvus()

    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("Server connected.")
    else:
        print("Server connect fail.")
        sys.exit(1)

    try:
        # Create collection example_collection if it doesn't exist.
        collection_name = 'example_collection'

        status, ok = milvus.has_collection(collection_name)
        if not ok:
            param = {
                'collection_name': collection_name,
                'dimension': _DIM,
                'index_file_size': _INDEX_FILE_SIZE,  # optional
                'metric_type': MetricType.L2  # optional
            }

            milvus.create_collection(param)

        # Show collections in Milvus server
        _, collections = milvus.show_collections()

        # present collection info
        _, info = milvus.collection_info(collection_name)
        print(info)

        # Describe example_collection
        _, collection = milvus.describe_collection(collection_name)
        print(collection)

        # 10000 random vectors with _DIM dimensions each, as a 2-D list.
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]
        # You can also use numpy to generate random vectors:
        #     `vectors = np.random.rand(10000, _DIM).astype(np.float32)`

        # Insert vectors into example_collection, return status and vectors id list
        status, ids = milvus.insert(collection_name=collection_name, records=vectors)

        # Flush the collection's inserted data to disk.
        milvus.flush([collection_name])

        # Get example_collection row count
        status, result = milvus.count_collection(collection_name)

        # IVF_FLAT index parameters. Search works without an index;
        # the index is there to make it faster.
        index_param = {
            'nlist': 2048
        }

        print("Creating index: {}".format(index_param))
        status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)

        # describe index, get information of index
        status, index = milvus.describe_index(collection_name)
        print(index)

        # Use the first 10 inserted vectors as queries
        query_vectors = vectors[0:10]

        # execute vector similarity search
        search_param = {
            "nprobe": 16
        }
        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'params': search_param
        }
        print("Searching ... ")
        status, results = milvus.search(**param)

        if status.OK():
            # The first query vector was itself inserted, so its top hit
            # should be at distance 0.0 or carry the id assigned on insert.
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

        # print results
        print(results)

        # Drop example_collection
        milvus.drop_collection(collection_name)
    finally:
        # Disconnect from Milvus
        milvus.disconnect()
Beispiel #15
0
def main():
    """Run the basic table workflow against a Milvus server over HTTP.

    Steps:
    - Connect to ``_HOST``:``_PORT``; print a message and exit with status 1 on failure.
    - Create ``demo_tables`` if it does not exist.
    - Insert 100000 random ``_DIM``-dimensional vectors and wait 6 seconds.
    - Build an IVFLAT index.
    - Search with the first 10 inserted vectors (top 1, nprobe 16) and print the results.
    - Drop the table and disconnect.

    The connection is closed even if a step in between raises.
    """
    milvus = Milvus(handler="HTTP")

    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("Server connected.")
    else:
        print("Server connect fail.")
        sys.exit(1)

    try:
        # Create table demo_tables if it doesn't exist.
        table_name = 'demo_tables'

        status, ok = milvus.has_table(table_name)
        if not ok:
            param = {
                'table_name': table_name,
                'dimension': _DIM,
                'index_file_size': _INDEX_FILE_SIZE,  # optional
                'metric_type': MetricType.L2  # optional
            }

            milvus.create_table(param)

        # Show tables in Milvus server
        _, tables = milvus.show_tables()

        # Describe demo_tables
        _, table = milvus.describe_table(table_name)
        print(table)

        # 100000 random vectors with _DIM dimensions each, as a 2-D list.
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]
        # You can also use numpy to generate random vectors:
        #     `vectors = np.random.rand(100000, _DIM).astype(np.float32).tolist()`

        # Insert vectors into demo_tables, return status and vectors id list
        status, ids = milvus.insert(table_name=table_name, records=vectors)

        # Wait for 6 seconds, until Milvus server persist vector data.
        time.sleep(6)

        # Get demo_tables row count
        status, result = milvus.count_table(table_name)

        # IVFLAT index parameters. Search works without an index;
        # the index is there to make it faster.
        index_param = {
            'index_type': IndexType.IVFLAT,  # choice ivflat index
            'nlist': 2048
        }

        status = milvus.create_index(table_name, index_param)

        # describe index, get information of index
        status, index = milvus.describe_index(table_name)
        print(index)

        # Use the first 10 inserted vectors as queries
        query_vectors = vectors[0:10]

        # execute vector similarity search
        param = {
            'table_name': table_name,
            'query_records': query_vectors,
            'top_k': 1,
            'nprobe': 16
        }
        status, results = milvus.search(**param)

        if status.OK():
            # The first query vector was itself inserted, so its top hit
            # should be at distance 0.0 or carry the id assigned on insert.
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

        # print results
        print(results)

        # Delete demo_tables
        milvus.drop_table(table_name)
    finally:
        # Disconnect from Milvus
        milvus.disconnect()
Beispiel #16
0
class SDKClient(object):
    """Context-managed wrapper around a ``Milvus`` client.

    Entering the context connects to ``host``:``port`` and returns this
    wrapper; leaving it disconnects. Client failures are re-raised as
    ``SDKClientConnectionException`` / ``SDKClientSearchVectorException``.
    """

    def __init__(self, host=None, port=None):
        """Remember the target server.

        ``host`` and ``port`` fall back to ``settings.MILVUS_SERVER_HOST`` and
        ``settings.MILVUS_SERVER_PORT`` when None.
        """
        self.host = settings.MILVUS_SERVER_HOST if host is None else host
        self.port = settings.MILVUS_SERVER_PORT if port is None else port
        # Set by init_client(); None while no client has been created.
        self.client = None

    def init_client(self):
        """Create a Milvus client and connect it to ``self.host``:``self.port``.

        Raises:
            SDKClientConnectionException: if ``connect`` raises, or returns a
                status that is not equal to ``Status.SUCCESS``.
        """
        self.client = Milvus()
        try:
            status = self.client.connect(host=self.host,
                                         port=self.port,
                                         timeout=settings.TIMEOUT)
        except Exception as exc:
            # Chain the original error so its traceback is not lost.
            raise SDKClientConnectionException(str(exc)) from exc

        if status != Status.SUCCESS:
            raise SDKClientConnectionException(str(status))

    def __enter__(self):
        """Connect and return this wrapper, so ``with SDKClient() as c`` binds it."""
        self.init_client()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Disconnect and forget the client; exceptions from the body propagate."""
        # Clear the attribute first so it is reset even if disconnect() raises.
        client, self.client = self.client, None
        if client is not None:
            client.disconnect()

    def search_vectors(self, table_id, query_records, topK):
        """Return the ``topK`` nearest results for ``query_records`` in ``table_id``.

        Raises:
            SDKClientSearchVectorException: if the client call raises or the
                returned status is not ``Status.SUCCESS``.
        """
        param = {
            'table_name': table_id,
            'query_records': query_records,
            'top_k': topK,
        }
        try:
            status, results = self.client.search_vectors(**param)
        except Exception as exc:
            raise SDKClientSearchVectorException(str(exc)) from exc

        if status != Status.SUCCESS:
            raise SDKClientSearchVectorException(str(status))

        return results

    def search_vectors_in_files(self,
                                table_id,
                                file_ids,
                                query_records,
                                topK,
                                query_ranges=None):
        """Search only within ``file_ids`` of ``table_id``, returning raw results.

        Raises:
            SDKClientSearchVectorException: if the client call raises or the
                returned status is not ``Status.SUCCESS``.
        """
        try:
            status, results = self.client.search_vectors_in_files(
                table_id,
                file_ids,
                query_records,
                topK,
                query_ranges,
                raw=True)
        except Exception as exc:
            raise SDKClientSearchVectorException(str(exc)) from exc

        if status != Status.SUCCESS:
            raise SDKClientSearchVectorException(str(status))

        return results
Beispiel #17
0

# Load feature vectors from a CSV file and insert them into a new Milvus
# collection named 'mnist'. Each row's position is used as its id.
# NOTE(review): the collection is declared with dimension 2, so the CSV is
# presumably expected to have exactly 2 columns; confirm.
data_name = 'D:/py_project/untitled1/polls/homework/vae/data.csv'
data = pd.read_csv(data_name)
vectors = data.values
# One id per CSV row. A hard-coded range(60000) breaks as soon as the CSV
# has a different number of rows.
vector_ids = list(range(len(vectors)))

milvus = Milvus()
milvus.connect(host='localhost', port='19530')
collection_name = 'mnist'

try:
    param = {'collection_name': collection_name, 'dimension': 2, 'index_file_size': 1024, 'metric_type': MetricType.L2}
    milvus.create_collection(param)
    milvus.insert(collection_name=collection_name, records=vectors, ids=vector_ids)
finally:
    # Close the connection even if creating the collection or inserting fails.
    milvus.disconnect()

Beispiel #18
0
def main():
    """Demonstrate the legacy table API: connect, create, insert, search, count.

    Steps:
    - Print the client version, connection status, connected flag and server version.
    - Describe ``table01`` and create it (dimension 256, IDMAP index, no raw
      vectors stored) when the describe call returns no table.
    - Insert 20 random vectors.
    - Wait 6 seconds, then search with 2 random query vectors (top 10).
    - Print the row count and disconnect.

    Each step's status is printed rather than checked.
    """
    dim = 256
    tbl = 'table01'
    client = Milvus()

    print('# Client version: {}'.format(client.client_version()))

    # Point _HOST / _PORT at your own server before running.
    connect_status = client.connect(host=_HOST, port=_PORT)
    print('# Connect Status: {}'.format(connect_status))
    print('# Is connected: {}'.format(client.connected))
    print('# Server version: {}'.format(client.server_version()))

    describe_status, table = client.describe_table(tbl)
    print('# Describe table status: {}'.format(describe_status))
    print('# Describe table:{}'.format(table))

    # Create `table01` only when describe_table produced nothing.
    if not table:
        schema = Prepare.table_schema(table_name=tbl,
                                      dimension=dim,
                                      index_type=IndexType.IDMAP,
                                      store_raw_vector=False)
        create_status = client.create_table(schema)
        print('# Create table status: {}'.format(create_status))

    _, tables = client.show_tables()
    pprint(tables)

    records = Prepare.records(
        [[random.random() for _ in range(dim)] for _ in range(20)])
    add_status, ids = client.add_vectors(table_name=tbl, records=records)
    print('# Add vector status: {}'.format(add_status))
    pprint(ids)

    # The server needs about 5s to persist the first batch of vectors,
    # so wait 6s before searching.
    print('# Waiting for 6s...')
    time.sleep(6)

    queries = Prepare.records(
        [[random.random() for _ in range(dim)] for _ in range(2)])
    search_status, results = client.search_vectors(table_name=tbl,
                                                   query_records=queries,
                                                   top_k=10)
    print('# Search vectors status: {}'.format(search_status))
    pprint(results)

    count_status, row_count = client.get_table_row_count(tbl)
    print('# Status: {}'.format(count_status))
    print('# Count: {}'.format(row_count))

    print('# Disconnect Status: {}'.format(client.disconnect()))
Beispiel #19
0
def main():
    """Insert ten fixed 16-dimensional vectors and check a nearest-neighbour search.

    Steps:
    - Connect to ``_HOST``:``_PORT``; print a message and return if that fails.
    - Create ``demo_table_01`` when ``has_table`` reports it missing.
    - Insert the vectors below and wait 6 seconds for the server to persist them.
    - Search for ``vectors[3]`` (top 1, nprobe 16) and print whether the hit
      is that same vector.
    - Delete the table and disconnect.

    The client is disconnected even if a step raises.
    """
    milvus = Milvus()

    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if not status.OK():
        print('Server connect fail: {}'.format(status))
        return

    try:
        # Create table demo_table_01 if it doesn't exist.
        table_name = 'demo_table_01'

        if not milvus.has_table(table_name):
            param = {
                'table_name': table_name,
                'dimension': 16,
                'index_file_size': 1024,
                'metric_type': MetricType.L2
            }

            milvus.create_table(param)

        # Show tables in Milvus server
        _, tables = milvus.show_tables()

        # Describe demo_table_01
        _, table = milvus.describe_table(table_name)

        # 10 vectors with 16 dimension
        vectors = [
            [
                0.66, 0.01, 0.29, 0.64, 0.75, 0.94, 0.26, 0.79, 0.61, 0.11, 0.25,
                0.50, 0.74, 0.37, 0.28, 0.63
            ],
            [
                0.77, 0.65, 0.57, 0.68, 0.29, 0.93, 0.17, 0.15, 0.95, 0.09, 0.78,
                0.37, 0.76, 0.21, 0.42, 0.15
            ],
            [
                0.61, 0.38, 0.32, 0.39, 0.54, 0.93, 0.09, 0.81, 0.52, 0.30, 0.20,
                0.59, 0.15, 0.27, 0.04, 0.37
            ],
            [
                0.33, 0.03, 0.87, 0.47, 0.79, 0.61, 0.46, 0.77, 0.62, 0.70, 0.85,
                0.01, 0.30, 0.41, 0.74, 0.98
            ],
            [
                0.19, 0.80, 0.03, 0.75, 0.22, 0.49, 0.52, 0.91, 0.40, 0.91, 0.79,
                0.08, 0.27, 0.16, 0.07, 0.24
            ],
            [
                0.44, 0.36, 0.16, 0.88, 0.30, 0.79, 0.45, 0.31, 0.45, 0.99, 0.15,
                0.93, 0.37, 0.25, 0.78, 0.84
            ],
            [
                0.33, 0.37, 0.59, 0.66, 0.76, 0.11, 0.19, 0.38, 0.14, 0.37, 0.97,
                0.50, 0.08, 0.69, 0.16, 0.67
            ],
            [
                0.68, 0.97, 0.20, 0.13, 0.30, 0.16, 0.85, 0.21, 0.26, 0.17, 0.81,
                0.96, 0.18, 0.40, 0.13, 0.74
            ],
            [
                0.11, 0.26, 0.44, 0.91, 0.89, 0.79, 0.98, 0.91, 0.09, 0.45, 0.07,
                0.88, 0.71, 0.35, 0.97, 0.41
            ],
            [
                0.17, 0.54, 0.61, 0.58, 0.25, 0.63, 0.65, 0.71, 0.26, 0.80, 0.28,
                0.77, 0.69, 0.02, 0.63, 0.60
            ],
        ]

        # Insert vectors into demo_table_01
        status, ids = milvus.add_vectors(table_name=table_name, records=vectors)

        # Wait for 6 seconds, since Milvus server persist vector data every 5 seconds by default.
        # You can set data persist interval in Milvus config file.
        time.sleep(6)

        # Get demo_table_01 row count
        status, result = milvus.get_table_row_count(table_name)

        # Use vectors[3] (the 4th vector) for similarity search
        query_vectors = [vectors[3]]

        # execute vector similarity search
        param = {
            'table_name': table_name,
            'query_records': query_vectors,
            'top_k': 1,
            'nprobe': 16
        }
        status, results = milvus.search_vectors(**param)

        # Only inspect the hits when the search succeeded; results of a
        # failed search are not meaningful.
        if not status.OK():
            print('Search failed: {}'.format(status))
        elif results[0][0].distance == 0.0 or results[0][0].id == ids[3]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

        # Delete demo_table_01
        milvus.delete_table(table_name)
    finally:
        # Disconnect from Milvus
        milvus.disconnect()
Beispiel #20
0
    # NOTE(review): this is the tail of a function whose header is not in
    # view. `client`, `collection_name`, `index_param` and `vectors` come
    # from the missing part; presumably a connected Milvus client, an
    # existing collection, IVF_PQ index parameters and the inserted
    # vectors. Confirm.

    # create `IVF_PQ` index; failure is only printed, not raised
    status = client.create_index(collection_name, IndexType.IVF_PQ, index_param)
    if status.OK():
        print("Create index IVF_PQ successfully\n")
    else:
        print("Create index fail: ", status)

    # use the first 10 inserted vectors as query vectors
    query_vectors = vectors[:10]

    # specify search param
    search_param = {
        "nprobe": 10
    }

    # top_k is 1: return the single approximate nearest neighbour per query.
    # NOTE(review): top_k is passed positionally as the 2nd argument; confirm
    # this matches the client's search() signature for this SDK version.
    status, result = client.search(collection_name, 1, query_vectors, params=search_param)

    if status.OK():
        # show search result
        print("Search successfully. Result:\n", result)
    else:
        # NOTE(review): unlike the index branch above, the failing status is not printed
        print("Search fail")

    # drop collection
    client.drop_collection(collection_name)

    # disconnect from server
    client.disconnect()
class ConnectionHandler:
    """Wrap a ``Milvus`` client with connect/disconnect handling and error handlers.

    ``err_handlers`` maps an exception class to a callable that receives the
    exception instance. Lookups use the exception's exact class; an
    exception with no registered handler is re-raised.

    ``_retry_times`` and ``_normal_times`` feed ``can_retry``. Each failed
    connection attempt raises the retry count by one. Each successful
    connection counts as "normal", and every
    ``settings.THRIFTCLIENT_NORMAL_TIME`` normal connections lower the
    retry count by one again.
    """

    def __init__(self, uri):
        self.uri = uri
        self._retry_times = 0   # failed connect attempts not yet decayed
        self._normal_times = 0  # successful connects since the last decay
        self.thrift_client = Milvus()
        self.err_handlers = {}  # exception class -> handler callable
        self.default_error_handler = None

    @contextmanager
    def connect_context(self):
        """Keep ``thrift_client`` connected for the duration of a ``with`` block.

        Connection attempts repeat while ``can_retry`` is true. A failure
        whose exact class has a registered handler is passed to that handler
        and counted as a retry; any other failure is re-raised. If retries
        run out without a successful connect, the block still runs on the
        unconnected client.

        The client is disconnected when the block exits, whether or not the
        block raised. If that disconnect fails, the client is replaced with
        a fresh ``Milvus()`` instance.
        """
        while self.can_retry:
            try:
                self.thrift_client.connect(uri=self.uri)
            except Exception as e:
                # Count the failure so can_retry eventually turns false.
                # Without this, a handler that returns normally would make
                # the loop retry forever.
                self._retry_times += 1
                handler = self.err_handlers.get(e.__class__, None)
                if handler:
                    handler(e)
                else:
                    raise
            else:
                self._normal_times += 1
                break

        # try/finally: an exception from the with-body is thrown in at the
        # yield, and without it the disconnect below would be skipped.
        try:
            yield
        finally:
            try:
                self.thrift_client.disconnect()
            except Exception:
                self.thrift_client = Milvus()

    def error_collector(self, func):
        """Decorate ``func`` so a ThriftException goes to its registered handler.

        If a handler is registered for the exception's exact class, it is
        called, the exception is logged via ``LOGGER.error``, and the wrapper
        returns None. Otherwise the exception is re-raised unchanged.
        """
        @wraps(func)
        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ThriftException as e:
                handler = self.err_handlers.get(e.__class__, None)
                if handler:
                    handler(e)
                else:
                    raise
                LOGGER.error(e)
        return inner

    def connect(self, func, handle_error=True):
        """Decorate ``func`` so each call runs inside ``connect_context``.

        With ``handle_error`` true, a ThriftException raised by ``func`` is
        passed to the handler registered for its exact class, and the wrapper
        then returns None. If no handler is registered, the exception is
        re-raised. With ``handle_error`` false, all exceptions propagate
        unchanged.
        """
        @wraps(func)
        def inner(*args, **kwargs):
            with self.connect_context():
                if handle_error:
                    try:
                        return func(*args, **kwargs)
                    except ThriftException as e:
                        handler = self.err_handlers.get(e.__class__, None)
                        if handler:
                            handler(e)
                        else:
                            raise
                else:
                    return func(*args, **kwargs)

        return inner

    @property
    def client(self):
        """The wrapped ``Milvus`` client instance."""
        return self.thrift_client

    def reconnect(self, uri=None):
        """Optionally switch to ``uri``, then replace the client with a fresh ``Milvus()``.

        The retry counters are left unchanged.
        """
        self.uri = uri if uri else self.uri
        self.thrift_client = Milvus()

    @property
    def can_retry(self):
        """Whether another connection attempt is allowed.

        Side effect: once ``_normal_times`` reaches
        ``settings.THRIFTCLIENT_NORMAL_TIME``, the retry count is lowered by
        one (not below 0) and that many normal connects are consumed.
        """
        if self._normal_times >= settings.THRIFTCLIENT_NORMAL_TIME:
            self._retry_times = self._retry_times - 1 if self._retry_times > 0 else 0
            self._normal_times -= settings.THRIFTCLIENT_NORMAL_TIME
        return self._retry_times <= settings.THRIFTCLIENT_RETRY_TIME

    def err_handler(self, exception):
        """Register an error handler; intended for use as a decorator.

        ``@handler.err_handler(SomeError)`` registers the decorated function
        for exactly ``SomeError``. Applied directly to a function
        (``@handler.err_handler``), it stores that function as
        ``default_error_handler`` and returns it.

        NOTE(review): ``default_error_handler`` is not read anywhere in this
        class; confirm whether callers use it.
        """
        if inspect.isclass(exception) and issubclass(exception, Exception):
            def wrappers(func):
                self.err_handlers[exception] = func
                return func
            return wrappers
        else:
            self.default_error_handler = exception
            return exception
Beispiel #22
0
    def search(self, img_id):
        """Encode one image with the trained model and save its 100 nearest stored images.

        Steps:
        - Restore the latest checkpoint.
        - Read ``img_id`` from the static homework directory as grayscale and
          reshape it to rows of 784 values scaled by 1/255.
        - Feed it through the model to get ``guessed_z``.
        - Search the 'mnist' Milvus collection for the top 100 matches.
        - Write ``self.train_images[id]`` for each hit to
          ``search_result/result_<rank>.png``.

        Returns:
            True on success, False if the Milvus search reports a failure.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        x_real, x_fake, cost, optimizer, z_noise, z_real, guessed_z = self.build_model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            model_file = tf.train.latest_checkpoint('D:/py_project/untitled1/polls/homework/vae/ckpt/')
            # Load the trained parameters
            saver.restore(sess, model_file)

            search_img_path = "D:/py_project/untitled1/polls/static/polls/homework/" + img_id
            search_img = cv2.imread(search_img_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None instead of raising when it cannot read the file.
            if search_img is None:
                raise FileNotFoundError(search_img_path)
            search_img = np.reshape(search_img, [-1, 784]) / 255

            # Get the feature vector of the image to search for
            test_z = sess.run(guessed_z, feed_dict={x_real: search_img})

            print(img_id, test_z)

            # Connect to the Milvus database and run the query; always disconnect.
            milvus = Milvus()
            milvus.connect(host='localhost', port='19530')
            try:
                status, result = milvus.search(collection_name='mnist', query_records=test_z, top_k=100,
                                               params={'nprobe': 16})
            finally:
                milvus.disconnect()

            if not status.OK():
                print('search failed:', status)
                return False

            # Save the search results; file names use the hit's rank within its row.
            result_img_path = "D:/py_project/untitled1/polls/static/polls/homework/search_result/result_"
            for row in result:
                for i, item in enumerate(row):
                    batch_test_pic = self.train_images[item.id]
                    pic_name = result_img_path + str(i) + ".png"
                    cv2.imwrite(pic_name, batch_test_pic)

            print('search success!')

            return True