コード例 #1
0
ファイル: evaluate.py プロジェクト: dowobeha/seq2seq
def evaluate(vocab: Vocabulary, corpus_filename: str, encoder: EncoderRNN,
             decoder: AttnDecoderRNN, max_src_length: int,
             max_tgt_length: int):

    device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    encoder.to(device)
    decoder.to(device)

    encoder.eval()
    decoder.eval()

    with torch.no_grad():

        corpus = Corpus(
            filename=corpus_filename,
            max_src_length=max_src_length,  # decoder.max_src_length,
            vocab=vocab,
            device=device)

        for batch in torch.utils.data.DataLoader(dataset=corpus, batch_size=1):

            input_tensor: torch.Tensor = batch["data"].permute(1, 0)

            encoder_outputs = encoder.encode_sequence(input_tensor)

            decoder_output = decoder.decode_sequence(
                encoder_outputs=encoder_outputs,
                start_symbol=corpus.characters.start_of_sequence.integer,
                max_length=max_tgt_length)
            _, top_i = decoder_output.topk(k=1)

            predictions = top_i.squeeze(dim=2).squeeze(dim=1).tolist()

            predicted_string = "".join(
                [corpus.characters[i].string for i in predictions])

            print(predicted_string)
コード例 #2
0
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # 循环,这里只使用长度限制,后面处理的时候把EOS去掉了。
        for _ in range(max_length):
            # Decoder forward一步
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # decoder_outputs是(batch=1, vob_size)
            # 使用max返回概率最大的词和得分
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # 把解码结果保存到all_tokens和all_scores里
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # decoder_input是当前时刻输出的词的ID,这是个一维的向量,因为max会减少一维。
            # 但是decoder要求有一个batch维度,因此用unsqueeze增加batch维度。
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # 返回所有的词和得分。
        return all_tokens, all_scores

######################################################################
# Run Evaluation
#

# 进入eval模式,从而去掉dropout。
encoder.eval()
decoder.eval()

# 构造searcher对象
searcher = GreedySearchDecoder(encoder, decoder)

# 测试
evaluateInput(encoder, decoder, searcher, voc)