def predict_xl(text, model: Transformer, device, is_beam_search=False):
    """Decode an input string with a Transformer, greedily or via beam search.

    :param text: input string; converted to char ids by get_input_char_index
    :param model: Transformer exposing .encoder/.decoder/.make_masks/.sos_idx
    :param device: torch device the input tensor is moved to
    :param is_beam_search: False -> greedy decoding, True -> beam search
    :return: decoded output string (mapped back via get_output_char)
    """
    input_id = get_input_char_index(text)
    input_tensor = torch.LongTensor([input_id]).to(device)
    batch_size, src_len = input_tensor.shape
    # Target starts as a single <sos> token per batch element.
    trg = input_tensor.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(input_tensor, trg)
    if not is_beam_search:
        # Encode once; re-run the decoder on the growing target each step.
        encoder_output = model.encoder(input_tensor, src_mask)
        result = []
        # Hard cap of 200 steps guards against never emitting <eos>.
        for _ in range(200):
            output = model.decoder(trg, encoder_output, trg_mask, src_mask)
            # Greedy: most likely token at the last time step.
            output = torch.argmax(output[:, -1], dim=1)
            next_id = output.item()
            result.append(next_id)
            # BUGFIX: compare via .item() instead of .numpy()[0] so decoding
            # also works on CUDA tensors (.numpy() requires a CPU tensor).
            if next_id == EOS_IDX:
                break
            # Append the chosen token and rebuild the masks for the new length.
            trg = torch.cat((trg, output.unsqueeze(1)), dim=1)
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
        return get_output_char(result)
    else:
        target = beam_search.beam_decode(input_tensor, model, beam_with=5)
        print(target)
        print(len(target[0][0]))
        # Drop the leading <sos> token before mapping ids back to chars.
        return get_output_char(target[0][0][1:])
def beam_decode(src_input, model: Transformer, beam_with=3, topk=1):
    '''
    Beam-search decode every sequence in a batch with a Transformer model.

    :param src_input: input char ids, shape [batch_size, src_len]
    :param model: Transformer exposing .encoder/.decoder/.make_masks,
                  .sos_idx and .eos_idx
    :param beam_with: beam search width (expansions kept per step)
    :param topk: number of candidate sentences to return per input
    :return: list with one entry per batch element; each entry is a list of
             up to topk token-id sequences ordered from <sos> onward
    '''
    batch_size, src_len = src_input.shape
    # One <sos> start token per batch element; kept immutable below.
    sos_trg = src_input.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(src_input, sos_trg)
    encoder_outputs = model.encoder(src_input, src_mask)
    decode_result = []
    for batch_index in range(batch_size):
        # Encoder output / source / target restricted to the current sentence.
        encoder_output_current = encoder_outputs[batch_index, :, :].unsqueeze(0)
        input_tensor = src_input[batch_index, :].unsqueeze(0)
        trg_current = sos_trg[batch_index, :].unsqueeze(0)
        # BUGFIX: index row 0 of the [1, 1] unsqueezed tensor, not
        # batch_index — the original raised IndexError for batch_index > 0.
        word_idx = trg_current[0, -1]
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # Min-queue on negated score. The monotonic tie-break counter keeps
        # tuple comparison from ever reaching the BeamSearchNode objects,
        # which raised TypeError when two entries had equal scores.
        nodes_queue = PriorityQueue()
        tie_break = 0
        node = BeamSearchNode(trg_current, None, word_idx, 0, 1)
        nodes_queue.put((-node.eval(), tie_break, node))
        tie_break += 1
        q_size = 1
        # Beam search, capped at 200 queue expansions.
        while True:
            if q_size > 200:
                break
            # Best unexpanded hypothesis.
            score, _, n = nodes_queue.get()
            # BUGFIX: use a local name for the growing target sequence; the
            # original reassigned the shared `trg` tensor here, corrupting
            # the <sos> rows read by the remaining batch elements.
            trg_seq = n.decoder_hidden
            src_mask, trg_mask = model.make_masks(input_tensor, trg_seq)
            # A finished hypothesis must contain more than the <sos> node.
            if n.word_index.item() == model.eos_idx and n.previous_node is not None:
                endnodes.append((score, n))
                if len(endnodes) > number_required:
                    break
                else:
                    continue
            output = model.decoder(trg_seq, encoder_output_current,
                                   trg_mask, src_mask)
            # Top beam_with continuations of the last time step.
            log_prob, indexs = torch.topk(output, beam_with)
            next_nodes = []
            for new_k in range(beam_with):
                decoded_t = indexs[0][-1][new_k].view(-1)
                log_p = log_prob[0][-1][new_k].item()
                trg_tmp = torch.cat((trg_seq, decoded_t.unsqueeze(1)), dim=1)
                # Accumulate log-probability and length along the hypothesis.
                node = BeamSearchNode(trg_tmp, n, decoded_t,
                                      log_p + n.log_prob, n.length + 1)
                next_nodes.append((-node.eval(), node))
            for next_score, next_node in next_nodes:
                nodes_queue.put((next_score, tie_break, next_node))
                tie_break += 1
            q_size += len(next_nodes) - 1
        # No hypothesis reached <eos>: fall back to the current best nodes.
        if len(endnodes) == 0:
            endnodes = [
                (entry[0], entry[2])
                for entry in (nodes_queue.get() for _ in range(topk))
            ]
        utterances = []
        # Lower (more negative) score = higher probability; keep the topk best.
        for score, n in sorted(endnodes, key=operator.itemgetter(0))[:topk]:
            # BUGFIX: .item() instead of .numpy()[0] so CUDA tensors work.
            utterance = [n.word_index.item()]
            # Backtrack from the end node to <sos>, then reverse.
            while n.previous_node is not None:
                n = n.previous_node
                utterance.append(n.word_index.item())
            utterances.append(utterance[::-1])
        decode_result.append(utterances)
    return decode_result