Example #1
0
def predict_xl(text, model: Transformer, device, is_beam_search=False):
    """Decode `text` through a Transformer, greedily or via beam search.

    :param text: input string; converted to char ids by get_input_char_index
    :param model: Transformer exposing encoder/decoder/make_masks and sos_idx
    :param device: torch device the input tensor is moved to
    :param is_beam_search: when True, delegate to beam_search.beam_decode
                           with a beam width of 5
    :return: decoded output string (via get_output_char)
    """
    input_id = get_input_char_index(text)
    input_tensor = torch.LongTensor([input_id]).to(device)
    batch_size, src_len = input_tensor.shape
    # Seed the target sequence with the start-of-sequence token.
    trg = input_tensor.new_full((batch_size, 1), model.sos_idx)
    src_mask, trg_mask = model.make_masks(input_tensor, trg)

    if not is_beam_search:
        # Encode once; the encoder output is reused for every decode step.
        encoder_output = model.encoder(input_tensor, src_mask)
        result = []
        # Greedy decoding, capped at 200 steps so a model that never emits
        # EOS cannot loop forever.
        for _ in range(200):
            output = model.decoder(trg, encoder_output, trg_mask, src_mask)
            next_token = torch.argmax(output[:, -1], dim=1)
            # BUGFIX: use .item() (works on CPU and CUDA tensors) instead
            # of .numpy()[0], which raises on CUDA tensors.
            token_id = next_token.item()
            result.append(token_id)
            if token_id == EOS_IDX:
                break
            trg = torch.cat((trg, next_token.unsqueeze(1)), dim=1)
            src_mask, trg_mask = model.make_masks(input_tensor, trg)
        return get_output_char(result)
    else:
        target = beam_search.beam_decode(input_tensor, model, beam_with=5)
        # target[0][0] is the best hypothesis; drop the leading SOS token.
        return get_output_char(target[0][0][1:])
Example #2
0
def beam_decode(src_input, model: Transformer, beam_with=3, topk=1):
    """Beam-search decode every sentence in a batch.

    :param src_input: input char ids, shape [batch_size, src_len]
    :param model: Transformer exposing encoder/decoder/make_masks plus
                  sos_idx and eos_idx attributes
    :param beam_with: beam search width (expansions kept per node)
    :param topk: number of hypotheses to return per sentence
    :return: list with one entry per batch element; each entry is a list of
             up to `topk` token-id lists, each starting with sos_idx
    """
    batch_size, src_len = src_input.shape
    trg = src_input.new_full((batch_size, 1), model.sos_idx)

    src_mask, trg_mask = model.make_masks(src_input, trg)
    encoder_outputs = model.encoder(src_input, src_mask)
    decode_result = []
    for batch_index in range(batch_size):
        # Encoder output / source slice for the current sentence only.
        encoder_output_current = encoder_outputs[batch_index, :, :].unsqueeze(0)
        input_tensor = src_input[batch_index, :].unsqueeze(0)
        trg_current = trg[batch_index, :].unsqueeze(0)
        # BUGFIX: index row 0 of the freshly sliced 1-row tensor. The
        # original indexed with `batch_index`, which only worked for the
        # first sentence in the batch.
        word_idx = trg_current[0, -1]

        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # Priority queue ordered by negated score, so the best node pops
        # first. `tie_break` keeps entries comparable on score ties --
        # BeamSearchNode itself is not orderable.
        nodes_queue = PriorityQueue()
        tie_break = 0
        root = BeamSearchNode(trg_current, None, word_idx, 0, 1)
        nodes_queue.put((-root.eval(), tie_break, root))
        q_size = 1

        # Expand nodes until enough hypotheses finish or the budget runs out.
        while True:
            if q_size > 200:
                break
            score, _, n = nodes_queue.get()
            # BUGFIX: bind the node's partial sequence to a local name.
            # The original rebound the outer `trg`, corrupting the start
            # tokens used by later batch elements.
            trg_seq = n.decoder_hidden
            src_mask, trg_mask = model.make_masks(input_tensor, trg_seq)
            # A non-root node ending in EOS is a finished hypothesis.
            if n.word_index.item() == model.eos_idx and n.previous_node is not None:
                endnodes.append((score, n))
                if len(endnodes) > number_required:
                    break
                continue

            output = model.decoder(trg_seq, encoder_output_current, trg_mask,
                                   src_mask)

            # Top `beam_with` continuations of the last decoded position.
            log_prob, indexs = torch.topk(output, beam_with)

            for new_k in range(beam_with):
                decoded_t = indexs[0][-1][new_k].view(-1)
                log_p = log_prob[0][-1][new_k].item()
                trg_tmp = torch.cat((trg_seq, decoded_t.unsqueeze(1)), dim=1)
                child = BeamSearchNode(trg_tmp, n, decoded_t,
                                       log_p + n.log_prob, n.length + 1)
                tie_break += 1
                nodes_queue.put((-child.eval(), tie_break, child))

            q_size += beam_with - 1

        # Nothing reached EOS: fall back to the current best queued nodes.
        if not endnodes:
            endnodes = []
            for _ in range(topk):
                s, _, nd = nodes_queue.get()
                endnodes.append((s, nd))

        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0))[:topk]:
            # Backtrack from the end node to the root to recover the
            # full token sequence, then reverse it.
            utterance = [n.word_index.item()]
            while n.previous_node is not None:
                n = n.previous_node
                utterance.append(n.word_index.item())
            utterances.append(utterance[::-1])
        decode_result.append(utterances)
    # BUGFIX: return after processing ALL batch elements. The original
    # `return` sat inside the for loop and decoded only the first sentence.
    return decode_result