示例#1
0
    def recall(self, request, context):
        """Recall candidate items for a user via a Milvus vector search.

        The user's embedding (from ``self.get_user_vector``) is sent as the
        single query vector of a boolean/vector search against
        ``self.collection_name``, asking for the 100 nearest stored
        ``embedding`` values under the L2 metric.

        Request / response messages::

            message RecallRequest {
                string log_id = 1;
                user_info.UserInfo user_info = 2;
                string recall_type = 3;
                uint32 request_num = 4;
            }

            message RecallResponse {
                message Error {
                    uint32 code = 1;
                    string text = 2;
                }
                message ScorePair {
                    string nid = 1;
                    float score = 2;
                };
                Error error = 1;
                repeated ScorePair score_pairs = 2;
            }

        Args:
            request: A ``RecallRequest``; only ``user_info`` is read here.
            context: Not used here (presumably the gRPC servicer context).

        Returns:
            A ``recall_pb2.RecallResponse``. On success ``error.code`` is 200
            and ``score_pairs`` holds one entry per hit, with ``nid`` set to
            the hit id as a string and ``score`` set to its L2 distance
            (smaller is closer). If a per-query result set is empty,
            ``error.code`` is 500 and ``error.text`` embeds the request.
        """
        recall_res = recall_pb2.RecallResponse()
        user_vector = self.get_user_vector(request.user_info)

        query_hybrid = {
            "bool": {
                "must": [{
                    "vector": {
                        "embedding": {
                            # NOTE(review): ``request.request_num`` is not
                            # honoured; the result size is fixed at 100.
                            "topk": 100,
                            "query": [user_vector],
                            "metric_type": "L2"
                        }
                    }
                }]
            }
        }

        results = self.milvus_client.search(self.collection_name,
                                            query_hybrid,
                                            fields=["embedding"])
        # One result set per query vector; exactly one vector is sent above.
        for entities in results:
            if len(entities) == 0:
                recall_res.error.code = 500
                recall_res.error.text = "Recall server get milvus fail. ({})".format(
                    str(request))
                return recall_res
            for topk_film in entities:
                # Only the hit id and distance are needed for the response.
                score_pair = recall_res.score_pairs.add()
                score_pair.nid = str(topk_film.id)
                score_pair.score = float(topk_film.distance)
        recall_res.error.code = 200
        return recall_res
示例#2
0
    def recall(self, request, context):
        """Recall candidate items for a user via a Milvus vector search.

        The user's embedding (from ``self.get_user_vector``) is searched in
        the ``"Movie"`` partition of ``self.collection_name``. No top-k is
        passed here, so the number of hits is decided by
        ``self.milvus_client`` -- confirm against the client wrapper.

        Request / response messages::

            message RecallRequest {
                string log_id = 1;
                user_info.UserInfo user_info = 2;
                string recall_type = 3;
                uint32 request_num = 4;
            }

            message RecallResponse {
                message Error {
                    uint32 code = 1;
                    string text = 2;
                }
                message ScorePair {
                    string nid = 1;
                    float score = 2;
                };
                Error error = 1;
                repeated ScorePair score_pairs = 2;
            }

        Args:
            request: A ``RecallRequest``; only ``user_info`` is read here.
            context: Not used here (presumably the gRPC servicer context).

        Returns:
            A ``recall_pb2.RecallResponse``. On success ``error.code`` is 200
            and ``score_pairs`` holds one entry per hit, with ``nid`` set to
            the hit id as a string and ``score`` set to the hit distance.
            If the search yields no result (``None`` or empty) or an empty
            per-query result set, ``error.code`` is 500 and ``error.text``
            embeds the request.
        """
        recall_res = recall_pb2.RecallResponse()

        def _fail():
            # Fill in the shared "search failed" error and return the response.
            recall_res.error.code = 500
            recall_res.error.text = "Recall server get milvus fail. ({})".format(
                str(request))
            return recall_res

        user_vector = self.get_user_vector(request.user_info)

        # NOTE(review): ``status`` is returned but never inspected. If it is a
        # pymilvus ``Status``, ``status.OK()`` should probably gate the code
        # below -- confirm the client type before relying on it.
        status, results = self.milvus_client.search(
            collection_name=self.collection_name,
            vectors=[user_vector],
            partition_tag="Movie")
        # Without this guard, ``None`` would crash the loop below and an empty
        # result would be reported as success with no candidates.
        if not results:
            return _fail()
        # One result set per query vector; exactly one vector is sent above.
        for entities in results:
            if len(entities) == 0:
                return _fail()
            for topk_film in entities:
                score_pair = recall_res.score_pairs.add()
                score_pair.nid = str(topk_film.id)
                score_pair.score = float(topk_film.distance)
        recall_res.error.code = 200
        return recall_res
示例#3
0
    def recall(self, request, context):
        """Recall candidate items for a user from a precomputed Redis list.

        Reads the list stored under the key ``"<user_id>##recall"`` and turns
        each ``"<item_id>#<score>"`` entry into a ``ScorePair``.

        Request / response messages::

            message RecallRequest {
                string log_id = 1;
                user_info.UserInfo user_info = 2;
                string recall_type = 3;
                uint32 request_num = 4;
            }

            message RecallResponse {
                message Error {
                    uint32 code = 1;
                    string text = 2;
                }
                message ScorePair {
                    string nid = 1;
                    float score = 2;
                };
                Error error = 1;
                repeated ScorePair score_pairs = 2;
            }

        Args:
            request: A ``RecallRequest``; only ``user_info.user_id`` is read.
            context: Not used here (presumably the gRPC servicer context).

        Returns:
            A ``recall_pb2.RecallResponse``. If the list is missing or empty,
            ``error.code`` is 500 and ``error.text`` embeds the request.
            Otherwise ``error.code`` is 200 and ``score_pairs`` holds the
            entries in list order.

        Raises:
            IndexError: if a list entry contains no ``#`` separator.
            ValueError: if the score part is not a valid float.
        """
        recall_res = recall_pb2.RecallResponse()
        user_id = request.user_info.user_id
        # LRANGE's stop index is inclusive, so this fetches up to 201 entries.
        # NOTE(review): confirm whether 200 entries (stop=199) was intended.
        redis_res = self.redis_cli.lrange("{}##recall".format(user_id), 0, 200)
        # Assuming a redis-py client: ``lrange`` returns an empty list, not
        # ``None``, for a missing key, so emptiness must be checked as well.
        if not redis_res:
            recall_res.error.code = 500
            recall_res.error.text = "Recall server get user_info from redis fail. ({})".format(
                str(request))
            return recall_res
        recall_res.error.code = 200
        for item in redis_res:
            # Without ``decode_responses=True`` the client yields ``bytes``,
            # whose ``split`` rejects a ``str`` separator. Assumes UTF-8 data.
            if isinstance(item, bytes):
                item = item.decode("utf-8")
            fields = item.split("#")
            item_id, score = fields[0], fields[1]
            score_pair = recall_res.score_pairs.add()
            score_pair.nid = item_id
            score_pair.score = float(score)
        return recall_res