Exemplo n.º 1
0
    def predict(self, data, cangtou, predict_param):
        """Greedily decode a 4-sentence, 7-character poem from the encoded input.

        Args:
            data: tuple of (input_sentence tensor, input_length int) — the
                encoder input.  # assumes input_sentence is (1, seq) — TODO confirm
            cangtou: optional acrostic string; when present, character i is
                forced as the first word of sentence i.
            predict_param: dict; only 'hard_rhyme' is read and forwarded to
                get_next_word.

        Returns:
            list of decoded words (no sentence separators).
        """
        hard_rhyme = predict_param['hard_rhyme']
        with torch.no_grad():
            encoder_hidden = self.encoder.initHidden(1)

            input_sentence, input_length = data
            encoder_outputs, encoder_hidden = self.encoder(
                input_sentence, [input_length], encoder_hidden)

            # Pad encoder_outputs up to INPUT_MAX_LENGTH, because the attention
            # module has this dimension fixed at INPUT_MAX_LENGTH.
            encoder_outputs_padded = torch.zeros(1,
                                                 self.input_max_len,
                                                 self.encoder.hidden_size,
                                                 device=device)
            for b in range(1):
                for ei in range(input_length):
                    encoder_outputs_padded[b, ei] = encoder_outputs[b, ei]

            decoder_input = torch.tensor([[SOS_token]],
                                         device=device)  # first input is START
            decoder_hidden = encoder_hidden  # Use last hidden state from encoder to start decoder

            sen_len = 7  # tentative (fixed line length for now)
            sen_num = 4
            decoded_words = []

            for i in range(sen_num):
                for j in range(sen_len):
                    # (sentence index, position-in-sentence) fed to the decoder.
                    position = torch.tensor([[i, j]],
                                            dtype=torch.float,
                                            device=device)
                    decoder_output, decoder_hidden, decoder_attention = self.decoder(
                        decoder_input, decoder_hidden, encoder_outputs_padded,
                        position)
                    if j == 0 and cangtou and i < len(cangtou):
                        # Acrostic: force the i-th cangtou character as the
                        # first word of sentence i. Unknown words map to the
                        # last vocab id.
                        top_word = cangtou[i]
                        top_id = torch.LongTensor(
                            [word2id.get(top_word, vocab_size - 1)])
                    else:
                        top_id, top_word = get_next_word(decoder_output.data,
                                                         decoded_words,
                                                         hard_rhyme=hard_rhyme)
                        if top_word == 'N':
                            # Constraint solver gave up; abandon this sentence.
                            # NOTE(review): this only breaks the inner loop —
                            # decoding continues with the next sentence.
                            print('cannot meet requirements')
                            break
                    decoded_words.append(top_word)
                    decoder_input = top_id.reshape(
                        (1, 1)).detach()  # detach from history as input

                # Run one extra decoder step at position 7 to advance the
                # hidden state across the sentence boundary.
                position = torch.tensor([[i, 7]],
                                        dtype=torch.float,
                                        device=device)
                tmp_decoder_output, tmp_decoder_hidden, tmp_decoder_attention = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs_padded,
                    position)
                decoder_hidden = tmp_decoder_hidden
                decoder_input = torch.tensor([[2]], device=device)  # '/' as next input

        return decoded_words
Exemplo n.º 2
0
    def predict(self, data, cangtou, predict_param):
        """Decode a poem sentence-by-sentence, re-encoding the running output.

        Unlike the plain greedy variant, after each sentence the full decoded
        text so far is packed back into ``input_sentence`` and pushed through
        the encoder again, so later sentences condition on everything
        generated so far.

        Args:
            data: tuple of (input_sentence tensor, input_length) — the first
                sentence, already encoded as ids.  # assumes shape (1, max_len) — TODO confirm
            cangtou: optional acrostic string; character i is forced as the
                first word of sentence i.
            predict_param: unused in this variant.

        Returns:
            list of decoded words, starting with the words of the original
            input sentence.
        """
        with torch.no_grad():
            encoder_hidden = self.encoder.initHidden(1)

            sen_len = 7  # tentative (fixed line length for now)
            sen_num = 4
            
            input_sentence, input_length = data
            # Seed the output with the input sentence's own words.
            decoded_words = [id2word[str(id.item())] for id in input_sentence[0]]
            decoder_input = torch.tensor([[2]], device=device)  # first input is '/' # Jun24
           
            # Only sen_num - 1 sentences are generated here; the first one is
            # the input itself.
            for i in range(sen_num-1):
                encoder_outputs, encoder_hidden = self.encoder(input_sentence, [input_length], encoder_hidden)

            # Pad encoder_outputs up to INPUT_MAX_LENGTH, because the attention
            # module has this dimension fixed at INPUT_MAX_LENGTH.
                encoder_outputs_padded = torch.zeros(1, self.input_max_len, self.encoder.hidden_size,
                                                 device=device)
                for b in range(1):
                    for ei in range(input_length):
                        encoder_outputs_padded[b, ei] = encoder_outputs[b, ei]

                decoder_hidden = encoder_hidden  # Use last hidden state from encoder to start decoder
                
                for j in range(sen_len):
                    # (sentence index, position-in-sentence) fed to the decoder.
                    position = torch.tensor([[i, j]], dtype=torch.float, device=device)
                    decoder_output, decoder_hidden, decoder_attention = self.decoder(
                        decoder_input, decoder_hidden, encoder_outputs_padded, position)
                    if j == 0 and cangtou and i < len(cangtou):
                        # Acrostic: force the i-th cangtou character as the
                        # first word of sentence i.
                        top_word = cangtou[i]
                        top_id = torch.LongTensor([word2id.get(top_word, vocab_size - 1)])
                    else:
                        top_id, top_word = get_next_word(decoder_output.data, decoded_words)
                        if top_word == 'N':
                            # NOTE(review): only breaks the inner loop; the
                            # outer sentence loop keeps going.
                            print('cannot meet requirements')
                            break
                    decoded_words.append(top_word)
                    decoder_input = top_id.reshape((1, 1)).detach()  # detach from history as input

                # Re-pack everything decoded so far (zero-padded to
                # input_max_len) as the next encoder input.
                li = [word2id[word] for word in decoded_words]
                for k in range(self.input_max_len-len(li)):
                    li.append(0)
                input_sentence = torch.tensor([li], dtype=torch.long, device=device) 
                input_length = torch.tensor([len(decoded_words)], dtype=torch.long, device=device) 
                decoder_input = torch.tensor([[2]], device=device)  # '/' as next input
                
        return decoded_words
        
Exemplo n.º 3
0
    def predict(self, data, cangtou, predict_param):
        """Greedily decode a 4-sentence, 7-character poem with a Transformer.

        Each step re-feeds the whole decoded prefix (with '/' separators) to
        the decoder and picks the next word from the log-softmax over the
        vocabulary via ``get_next_word``.

        Args:
            data: tuple of (src_seq, src_pos) encoder input tensors.
            cangtou: unused in this variant (kept for interface parity with
                the other predict implementations).
            predict_param: unused in this variant.

        Returns:
            list of decoded words, without '/' separators.
        """
        with torch.no_grad():
            src_seq, src_pos = data

            enc_output, *_ = self.encoder(src_seq, src_pos)

            sen_len = 7
            sen_num = 4
            # decoded_words is the decoder context (starts with '/' — SOS/0
            # would make decoder_output NaN); rt_decoded_words is what the
            # caller receives, with no separators.
            decoded_words = ['/']
            rt_decoded_words = []

            for i in range(sen_num):
                for j in range(sen_len):
                    dec_seq = torch.tensor(
                        [[word2id[word] for word in decoded_words]],
                        dtype=torch.long,
                        device=device)
                    # Positions are 1-based over the current prefix.
                    dec_pos = torch.tensor(
                        [[k + 1 for k in range(len(decoded_words))]],
                        dtype=torch.long,
                        device=device)
                    dec_output, *_ = self.decoder(dec_seq, dec_pos, src_seq,
                                                  enc_output)
                    # Keep only the last time step: (batch, d_h).
                    dec_output = dec_output[:, -1, :]
                    word_prob = F.log_softmax(self.tgt_word_prj(dec_output),
                                              dim=1)
                    top_id, top_word = get_next_word(word_prob.data,
                                                     rt_decoded_words)
                    if top_word == 'N':
                        # Constraint solver gave up; abandon this sentence.
                        print('cannot meet requirements')
                        break
                    decoded_words.append(top_word)
                    rt_decoded_words.append(top_word)
                # BUG FIX: the original `if not i == sen_num:` was always true
                # (i never reaches sen_num inside range(sen_num)), so a dead
                # separator was appended after the final sentence as well.
                # Only append '/' when another sentence will actually be
                # decoded; the returned rt_decoded_words is unchanged.
                if i < sen_num - 1:
                    decoded_words.append('/')

        return rt_decoded_words
Exemplo n.º 4
0
def evaluate(encoder, decoder, sentence):
    """Greedily decode up to 28 words from ``sentence`` with the given models.

    Args:
        encoder: seq2seq encoder module providing ``initHidden`` and
            ``hidden_size``.
        decoder: attention decoder called as
            ``decoder(input, hidden, encoder_outputs)``.
        sentence: encoded input tensor.  # assumes shape (1, seq_len) — TODO confirm

    Returns:
        list of decoded words; stops early if the constraint solver fails.
    """
    with torch.no_grad():  # inference only — no backprop, saves memory
        hidden = encoder.initHidden(1)

        seq_len = sentence.size(1)
        enc_outputs, hidden = encoder(sentence, [seq_len], hidden)

        # Pad encoder outputs to INPUT_MAX_LENGTH: the attention module has
        # this dimension hard-wired.
        padded = torch.zeros(1,
                             data_utils.INPUT_MAX_LENGTH,
                             encoder.hidden_size,
                             device=device)
        padded[0, :seq_len] = enc_outputs[0, :seq_len]

        dec_input = torch.tensor([[data_utils.SOS_token]], device=device)
        dec_hidden = hidden  # seed the decoder with the encoder's final state

        max_steps = 28  # tentative cap on output length
        words = []

        for _ in range(max_steps):
            dec_output, dec_hidden, _attention = decoder(
                dec_input, dec_hidden, padded)
            top_i, word = constrains.get_next_word(dec_output.data, words)
            if word == 'N':
                # Constraint solver could not produce a valid word — stop.
                print('cannot meet requirements')
                break
            dec_input = top_i.reshape((1, 1)).detach()  # cut the autograd history
            words.append(word)

        return words