Example #1
    def generate(self, history):
        try:
            print(history)

            # Concatenate the history into one context: [CLS] utr1 [SEP] utr2 [SEP] ...
            history_ids = [self.tokenizer.encode(v) for v in history]
            input_ids = [self.tokenizer.cls_token_id]
            for history_utr in history_ids:
                input_ids.extend(history_utr)
                input_ids.append(self.tokenizer.sep_token_id)

            # Duplicate the context so batch_size candidates can be sampled in parallel
            input_ids = [copy.deepcopy(input_ids) for _ in range(self.batch_size)]
            curr_input_tensors = torch.tensor(input_ids).long().to(self.device)


            # Sample a batch of candidate responses, then rerank them with the MMI model
            candidate_responses = self._make_dialogue_response(curr_input_tensors)
            assert len(candidate_responses) >= 1
            best_response_ids = self._make_mmi_output(candidate_responses, history_ids)
            best_response_chars = self.tokenizer.convert_ids_to_tokens(best_response_ids)
            return best_response_chars
        except Exception as e:
            LOGGER.error("FAIL GEN: {}".format(str(e)))
            traceback.print_exc()
            return []
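generate relies on a _make_mmi_output helper that does not appear in these examples. In MMI (maximum mutual information) reranking, each candidate is typically scored by running the backward model over the dialogue in reverse order (response first, most recent history last) and keeping the candidate with the lowest loss. A minimal sketch under that assumption; the exact input layout is an assumption, not taken from the source:

    def _make_mmi_output(self, candidate_responses, history_ids):
        # Hypothetical sketch: score each candidate with the MMI (backward)
        # model on the reversed dialogue and return the lowest-loss one.
        best_response_ids, min_loss = candidate_responses[0], float('Inf')
        for response_ids in candidate_responses:
            # Reversed order: [CLS] response [SEP] newest_utterance [SEP] ...
            mmi_input = [self.tokenizer.cls_token_id] + response_ids + [self.tokenizer.sep_token_id]
            for history_utr in reversed(history_ids):
                mmi_input.extend(history_utr)
                mmi_input.append(self.tokenizer.sep_token_id)
            mmi_tensor = torch.tensor(mmi_input).long().unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Old tuple-style transformers API: loss comes first when labels are given
                loss = self.mmi_model(input_ids=mmi_tensor, labels=mmi_tensor)[0].item()
            if loss < min_loss:
                min_loss, best_response_ids = loss, response_ids
        return best_response_ids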
Example #2
    def __init__(self):
        try:
            self.device = 'cuda' if config_model.use_cuda else 'cpu'
            LOGGER.info('using device: {}'.format(self.device))
            if self.device == 'cuda':
                os.environ["CUDA_VISIBLE_DEVICES"] = config_model.device_nums
            self.tokenizer = BertTokenizer(config_model.vocab_path)

            # dialogue model
            self.dialogue_model = GPT2LMHeadModel.from_pretrained(config_model.dialogue_model_path)
            self.dialogue_model.to(self.device)
            self.dialogue_model.eval()

            # mmi model
            self.mmi_model = GPT2LMHeadModel.from_pretrained(config_model.mmi_model_path)
            self.mmi_model.to(self.device)
            self.mmi_model.eval()

            self.max_sequence_len = config_model.max_len
            self.batch_size = config_model.batch_size
            self.repetition_penalty = config_model.repetition_penalty
            self.temperature = config_model.temperature
            self.debug = config_model.debug
            self.topk = config_model.topk
            self.topp = config_model.topp


        except Exception as e:
            LOGGER.error("FAIL INIT: {}".format(str(e)))
            traceback.print_exc()
            sys.exit(-1)
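__init__ pulls every knob from a config_model module that is not part of these examples. A hypothetical stand-in showing the attributes it reads (the values are illustrative defaults, not the project's real configuration):

    # config_model.py -- hypothetical stand-in for the fields read above
    use_cuda = True
    device_nums = '0'                         # value for CUDA_VISIBLE_DEVICES
    vocab_path = 'vocab/vocab_small.txt'      # BERT-style vocab file
    dialogue_model_path = 'model/dialogue/'   # forward GPT-2 checkpoint dir
    mmi_model_path = 'model/mmi/'             # backward (MMI) checkpoint dir
    max_len = 25                              # max tokens generated per reply
    batch_size = 5                            # number of candidate replies
    repetition_penalty = 1.0
    temperature = 1.0
    debug = False
    topk = 8                                  # top-k sampling cutoff
    topp = 0.0                                # nucleus (top-p) cutoff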
Example #3
 def get_history(self,session_id):
     try:
         if session_id not in self.history_dict or "history" not in self.history_dict[session_id]:
             return []
         else:
             return self.history_dict[session_id]["history"][-self.max_history_len:]
     except Exception as e:
         LOGGER.error("FAIL update history: session_id: {}, error: {}".format(str(session_id), str(e)))
         return []
Example #4
 def update_history(self,session_id, new_input_text):
     try:
         if session_id not in self.history_dict:
             self.history_dict[session_id] = {
                 "history": [],
                 "modified_time": time.time()
             }
         self.history_dict[session_id]["history"].append(new_input_text)
         self.history_dict[session_id]["modified"] = time.time()
         return True
     except Exception as e:
         LOGGER.error("FAIL update history: session_id: {}, error: {}".format(str(session_id), str(e)))
         return False
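Together, update_history and get_history give each session a sliding window of recent turns. A quick illustration (the keeper instance and session id are hypothetical):

    keeper.update_history(session_id=42, new_input_text='hello')
    recent = keeper.get_history(session_id=42)   # at most max_history_len turns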
Example #5
 def _make_dialogue_response(self, input_tensors):
     try:
         generated = []
         # Indices of responses that have finished, i.e. already emitted sep_token_id
         finish_set = set()
         # Generate at most max_sequence_len tokens
         for _ in range(self.max_sequence_len):
             outputs = self.dialogue_model(input_ids=input_tensors)
             next_token_logits = outputs[0][:, -1, :]
             # Repetition penalty: divide the logit of every token already
             # generated for a sample to lower its sampling probability
             for index in range(self.batch_size):
                 for token_id in set([token_ids[index] for token_ids in generated]):
                     next_token_logits[index][token_id] /= self.repetition_penalty
             next_token_logits = next_token_logits / self.temperature
             # Set the [UNK] logit to -inf so the model can never predict [UNK]
             for next_token_logit in next_token_logits:
                 next_token_logit[self.tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
             filtered_logits = self._top_k_top_p_filtering(next_token_logits, top_k=self.topk, top_p=self.topp)
             # torch.multinomial draws num_samples elements without replacement;
             # higher-weight tokens are more likely to be picked. Returns indices.
             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
             # Mark every response that has just generated [SEP]
             for index, token_id in enumerate(next_token[:, 0]):
                 if token_id == self.tokenizer.sep_token_id:
                     finish_set.add(index)
             # Stop once every response in the batch has emitted [SEP]
             if all(index in finish_set for index in range(self.batch_size)):
                 break
             generated.append([token.item() for token in next_token[:, 0]])
             # Append the newly sampled tokens to the running model input
             input_tensors = torch.cat((input_tensors, next_token), dim=-1)
         candidate_responses = []  # all generated candidate responses
         for batch_index in range(self.batch_size):
             response = []
             for token_index in range(len(generated)):
                 if generated[token_index][batch_index] != self.tokenizer.sep_token_id:
                     response.append(generated[token_index][batch_index])
                 else:
                     break
             candidate_responses.append(response)
         return candidate_responses
     except Exception as e:
         LOGGER.error("FAIL make response: {}".format(str(e)))
         traceback.print_exc()
         return []
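_make_dialogue_response delegates to a _top_k_top_p_filtering helper that is not reproduced here. A sketch following the widely circulated Hugging Face example implementation, which is presumably what the project uses, adapted for batched (2-D) logits:

    def _top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        # Keep only the top_k highest-scoring tokens per row
        top_k = min(top_k, logits.size(-1))
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
        # Nucleus filtering: drop tokens beyond cumulative probability top_p
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits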
Example #6
    def post(self):
        response = {'status': 0, 'data': {}, 'message': 'fail'}
        try:
            session_id = self.get_argument("sessionId")
            input_text = self.get_argument("text")
        except Exception as e:
            LOGGER.error("FAIL receive args: {}".format(str(e)))
            response['message'] = str(e)
            self.finish(response)
            return

        try:
            st = time.time()
            session_id = int(session_id)
            keeper_partition = session_id % config_instance.num_keepers
            keepers[keeper_partition].update_history(session_id=session_id,
                                                     new_input_text=input_text)
            history = keepers[keeper_partition].get_history(
                session_id=session_id)
            generate_chars = worker.generate(history)
            print(generate_chars)
            if len(generate_chars) == 0:
                response['message'] = "fail generate response text"
                self.finish(response)
                return
            generate = "".join(generate_chars)
            keepers[keeper_partition].update_history(session_id=session_id,
                                                     new_input_text=generate)
            body_info = {
                'sessionId': session_id,
                'input': input_text,
                'output': generate
            }
            print(body_info)
            LOGGER.info(
                "receive: session_id: {}, input_text: {}, back: {}, cost: {} ms"
                .format(str(session_id), input_text, json.dumps(body_info),
                        (time.time() - st) * 1000))
            response['data'] = body_info
            response['status'] = 1
            response['message'] = 'success'
            self.finish(response)

        except Exception as e:
            LOGGER.error("FAIL make resonse: {}".format(str(e)))
            response['message'] = str(e)
            self.finish(response)
        return
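A hypothetical client call against the handler above; the host, port and route are assumptions, since the URL mapping is not part of these examples:

    import requests

    resp = requests.post('http://localhost:8000/chat',
                         data={'sessionId': '12345', 'text': 'hello'})
    print(resp.json())  # e.g. {'status': 1, 'data': {...}, 'message': 'success'}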
Example #7
    def run(self):
        while True:
            time.sleep(1800)
            cur_update_time = time.time()
            expire_list = []

            # Iterate over a snapshot of the keys: entries are popped inside the
            # loop, and mutating a dict while iterating it raises a RuntimeError.
            for key in list(self.history_dict.keys()):
                try:
                    # Capture the entry before popping it so it can still be
                    # serialized into expire_list afterwards.
                    entry = self.history_dict[key]
                    if "history" not in entry:
                        self.history_dict.pop(key)
                        expire_list.append(json.dumps({
                            "session_id": key,
                            "history": [],
                            "last_modified": time.time() - 1800
                        }))
                    elif "modified_time" in entry and type(entry["modified_time"]) == float:
                        if cur_update_time - entry["modified_time"] > 1800:
                            self.history_dict.pop(key)
                            expire_list.append(json.dumps({
                                "session_id": key,
                                "history": entry["history"],
                                "last_modified": entry["modified_time"]
                            }))
                    else:
                        self.history_dict.pop(key)
                        expire_list.append(json.dumps({
                            "session_id": key,
                            "history": entry["history"],
                            "last_modified": time.time() - 1800
                        }))
                except Exception as e:
                    LOGGER.error("bad exec: {}, reason: {}".format(str(key),str(e)))
                    traceback.print_exc()
                    continue


            with open(self.expire_save_path, 'a') as fw:
                for expire_session in expire_list:
                    fw.write(expire_session+'\n')

            with open(self.history_save_path, 'w') as fw:
                json.dump(self.history_dict,fw)
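run() is an infinite loop with a 30-minute sleep, which suggests the keeper is a threading.Thread subclass. A hypothetical wiring of the pieces shown in these examples (the class name and constructor are assumptions; config_instance.num_keepers appears in Example #6):

    keepers = [HistoryKeeper() for _ in range(config_instance.num_keepers)]
    for keeper in keepers:
        keeper.daemon = True   # don't block interpreter shutdown
        keeper.start()         # kicks off the expiry loop in run() above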
Example #8
 def encode_to_ids(self, text):
     try:
         return self.tokenizer.encode(text)
     except Exception as e:
         LOGGER.error("FAIL INIT: {}".format(str(e)))
         traceback.print_exc()