Esempio n. 1
0
        def gen():
            """Endlessly yield NER feature dicts for the estimator input_fn.

            Each yielded dict holds 'input_ids', 'input_mask' and
            'input_type_ids' lists built from the current ``ner_bert.msg``.
            NOTE(review): loops forever; the consumer is expected to stop
            pulling from the generator when it is done.
            """
            # Loop-invariant setup hoisted out of the loop: the tokenizer and
            # logger do not depend on the incoming message, so build them once
            # instead of on every iteration.
            tokenizer = FullTokenizer(vocab_file=os.path.join(self.args.bert_model_dir, 'vocab_ner.txt'))
            # Windows does not support logger in MP environment, thus get a new
            # logger inside the process for better compatibility.
            logger = set_logger(colored('WORKER-%d', 'yellow'), self.verbose)
            logger.info('ready and listening!')
            while True:
                msg = ner_bert.msg
                # If msg is a list of lists, the input is considered already
                # tokenized; otherwise it is tokenized below.
                is_tokenized = all(isinstance(el, list) for el in msg)
                logger.info(is_tokenized)
                # Tokenize the received text and convert it to id features.
                tmp_f = list(extract_features.ner_convert_lst_to_features(msg, self.max_seq_len, tokenizer, logger,
                                                                          is_tokenized, self.mask_cls_sep))
                yield {
                    'input_ids': [f.input_ids for f in tmp_f],
                    'input_mask': [f.input_mask for f in tmp_f],
                    'input_type_ids': [f.input_type_ids for f in tmp_f],
                }
Esempio n. 2
0
        def gen():
            """Endlessly yield classification feature dicts for the estimator.

            Each yielded dict holds 'input_ids', 'input_mask' and
            'input_type_ids' lists built from the current ``bert.msg``.
            NOTE(review): loops forever; the consumer is expected to stop
            pulling from the generator when it is done.
            """
            # Loop-invariant setup hoisted out of the loop: the tokenizer and
            # logger do not depend on the incoming message, so build them once
            # instead of on every iteration.
            # TODO(review): vocab path is machine-specific; should come from
            # self.args.bert_model_dir like the NER worker does.
            tokenizer = FullTokenizer(vocab_file="D:\\LiuXianXian\\pycharm--code\\flask4bert\\BertModel\\checkpoints\\vocab_classify.txt")
            # Windows does not support logger in MP environment, thus get a new
            # logger inside the process for better compatibility.
            logger = set_logger(colored('WORKER-%d', 'yellow'), self.verbose)
            while True:
                # If msg is a list of lists, the input is considered already
                # tokenized; otherwise it is tokenized below.
                is_tokenized = all(isinstance(el, list) for el in bert.msg)
                logger.info(is_tokenized)
                # Tokenize the received text and convert it to id features.
                tmp_f = list(extract_features.convert_lst_to_features(bert.msg, self.max_seq_len, tokenizer, logger,
                                                                      is_tokenized, self.mask_cls_sep))
                yield {
                    'input_ids': [f.input_ids for f in tmp_f],
                    'input_mask': [f.input_mask for f in tmp_f],
                    'input_type_ids': [f.input_type_ids for f in tmp_f],
                }
Esempio n. 3
0
    def run_classify(self, r):
        """Consume one prediction from *r* and map it to label names.

        Args:
            r: iterator of estimator prediction dicts; each dict is expected
               to hold per-class probabilities under the 'encodes' key
               (assumed shape [batch, num_labels] — TODO confirm).

        Returns:
            list[list]: for each example, the labels whose probability
            exceeded the 0.5 threshold, mapped through ``self.id2label``
            (unknown indices map to -1).
        """
        # Windows does not support logger in MP environment, thus get a new
        # logger inside the process for better compatibility.
        logger = set_logger(colored('WORKER-%d', 'yellow'), self.verbose)
        logger.info('use device %s, load graph from %s' %
                    ('cpu' if len(self.device_map) <= 0 else ('gpu: %s' % ",".join(self.device_map)), self.graph_path))

        # import_tf presumably configures the TF runtime for the selected
        # devices; keep the call for its side effects even though the module
        # object itself is not used below.
        tf = import_tf(self.device_map, self.verbose, use_fp16=self.use_fp16)

        prediction = next(r)
        pred_label_result = []
        pred_score_result = []
        for class_probabilities in prediction["encodes"]:
            single_label_result = []
            single_score_result = []
            for idx, class_probability in enumerate(class_probabilities):
                # Multi-label decision: keep every class above the 0.5
                # threshold, not just the argmax.
                if class_probability > 0.5:
                    single_label_result.append(self.id2label.get(idx, -1))
                    single_score_result.append(class_probability)
            pred_label_result.append(single_label_result)
            pred_score_result.append(single_score_result)

        to_client = {'pred_label': pred_label_result, 'score': pred_score_result}
        logger.info(to_client)

        return pred_label_result
Esempio n. 4
0
    def run_ner(self, r):
        """Consume one NER prediction from *r* and decode it to labels.

        Args:
            r: iterator of estimator prediction dicts; each dict is expected
               to hold 'predicate_index' and 'token_label_index' arrays
               (per-example class indices — TODO confirm shapes).

        Returns:
            The token-label lists from ``ner_result_to_json``, each with the
            example's predicted predicate appended as its final element
            (the lists are mutated in place below).
        """
        # Windows does not support logger in MP environment, thus get a new
        # logger inside the process for better compatibility.
        logger = set_logger(colored('WORKER-%d', 'yellow'), self.verbose)

        # import_tf presumably configures the TF runtime for the selected
        # devices; keep the call for its side effects even though the module
        # object itself is not used below.
        tf = import_tf(self.device_map, self.verbose, use_fp16=self.use_fp16)

        prediction = next(r)
        predicate_index = prediction["predicate_index"]
        token_label_index = prediction["token_label_index"]

        # One single-element list per example: the decoded predicate label
        # (unknown indices map to -1).
        predicate_result = [[self.predicate_id2label.get(i, -1)] for i in predicate_index]

        token_label_result, pred_ids_result = self.ner_result_to_json(token_label_index, self.token_id2label)

        # Append each example's predicate to its token-label list. This
        # mutates the entries of token_label_result in place, so the
        # returned lists include the predicate as their last element —
        # matching the original behavior.
        for labels, predicate in zip(token_label_result, predicate_result):
            labels.append(predicate)

        logger.info(token_label_result)
        return token_label_result
Esempio n. 5
0
 def __init__(self, args, device_map, graph_path, mode, id2label, predicate_id2label, token_id2label):
     """Cache runtime options and label lookup tables on the worker.

     All configuration is copied off *args* up front so the worker does not
     depend on the argument object after construction (except via self.args).
     """
     super().__init__()
     # Device / graph configuration.
     self.device_map = device_map
     self.graph_path = graph_path
     self.gpu_memory_fraction = args.gpu_memory_fraction
     self.use_fp16 = args.fp16
     # Prefetching only applies to GPU workers; CPU workers get None.
     self.prefetch_size = args.prefetch_size if len(self.device_map) > 0 else None
     # Feature-extraction options.
     self.max_seq_len = args.max_seq_len
     self.mask_cls_sep = args.mask_cls_sep
     # Label-id lookup tables for the three prediction heads.
     self.id2label = id2label
     self.predicate_id2label = predicate_id2label
     self.token_id2label = token_id2label
     # Misc worker state.
     self.verbose = args.verbose
     self.logger = set_logger(colored('WORKER-%d', 'yellow'), args.verbose)
     self.mode = mode
     self.args = args
     self.daemon = True