示例#1
0
def validation_new(encoder, decoder, val_dataloader, lang_en, m_type):
    """Greedily decode the validation set and return its corpus BLEU score.

    Each batch is decoded token by token for at most ``sl + 20`` steps
    (``sl`` = source length), stopping early once ``EOS_token`` is emitted.
    Because ``topi.item()`` is used, each step must yield a single token, so
    batches are presumably of size 1 -- TODO confirm with the dataloader.

    Args:
        encoder: Module exposing ``initHidden(bs)`` and called as
            ``encoder(src_ids, hidden) -> (outputs, hidden)``.
        decoder: Module called as ``decoder(input, hidden, en_out)`` when
            ``m_type == "attention"``, else ``decoder(input, hidden)``;
            returns ``(output, hidden)``.
        val_dataloader: Yields batches whose ``data[0]`` is the source id
            tensor and ``data[-1]`` the reference (presumably the raw target
            sentence -- verify against the dataset).
        lang_en: Target vocabulary passed to ``convert_id_list_2_sent``.
        m_type: ``"attention"`` feeds the encoder outputs to the decoder.

    Returns:
        The first element of ``BLEU_SCORE().corpus_bleu(...)`` over all batches.
    """
    encoder.train(False)
    decoder.train(False)
    pred_corpus = []
    true_corpus = []
    bl = BLEU_SCORE()
    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for data in val_dataloader:
            encoder_i = data[0].cuda()
            bs, sl = encoder_i.size()[:2]
            en_h = encoder.initHidden(bs)
            en_out, decoder_hidden = encoder(encoder_i, en_h)
            decoder_input = torch.tensor([[SOS_token]] * bs).cuda()
            d_out = []
            for _ in range(sl + 20):
                if m_type == "attention":
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden, en_out)
                else:
                    decoder_output, decoder_hidden = decoder(
                        decoder_input, decoder_hidden)
                _, topi = decoder_output.topk(1)
                # The greedy choice becomes the next decoder input.
                decoder_input = topi.squeeze().detach().view(-1, 1)
                token = topi.item()
                d_out.append(token)
                if token == EOS_token:
                    break
            true_corpus.append(data[-1])
            pred_corpus.append(convert_id_list_2_sent(d_out, lang_en))
    return bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
示例#2
0
def validation(encoder, decoder, dataloader, loss_fun, lang_en, max_len, m_type):
    """Compute the validation loss and corpus BLEU score.

    Args:
        encoder, decoder: Models passed to ``encode_decode``.
        dataloader: Yields batches whose ``data[0]`` is the source id tensor
            and ``data[1]`` the target id tensor.
        loss_fun: Called as ``loss_fun(logits_2d, targets_1d)``; its value is
            weighted by batch size when averaged.
        lang_en: Target vocabulary; ``lang_en.n_words`` is the logits width.
        max_len: Forwarded to ``encode_decode``.
        m_type: Model type forwarded to ``encode_decode``.

    Returns:
        ``(mean_loss, bleu)``. ``mean_loss`` is the batch-size-weighted mean of
        ``loss_fun`` over all batches. ``bleu`` is the first element of
        ``BLEU_SCORE().corpus_bleu(...)``.
    """
    encoder.train(False)
    decoder.train(False)
    pred_corpus = []
    true_corpus = []
    running_loss = 0
    running_total = 0
    bl = BLEU_SCORE()
    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for data in dataloader:
            encoder_i = data[0].cuda()
            decoder_i = data[1].cuda()
            bs = encoder_i.size(0)
            # rand_num=0 -- presumably disables teacher forcing; confirm in encode_decode.
            out, _ = encode_decode(encoder, decoder, encoder_i, decoder_i,
                                   max_len, m_type, rand_num=0)
            # `out` is 3-D with the vocabulary on the last axis (see view/max below).
            outo = out.view(-1, lang_en.n_words)
            decoder_io = decoder_i.view(-1)
            loss = loss_fun(outo.float(), decoder_io.long())
            running_loss += loss.item() * bs
            running_total += bs
            pred = torch.max(out, dim=2)[1]
            for t, p in zip(decoder_i, pred):
                true_corpus.append(convert_idx_2_sent(t, lang_en))
                pred_corpus.append(convert_idx_2_sent(p, lang_en))
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return running_loss / running_total, score
示例#3
0
def validation_new(encoder, decoder, val_dataloader, lang_en, lang_vi, m_type, verbose=False, replace_unk=False):
    """Greedily decode the validation set and return BLEU plus diagnostics.

    Each batch is decoded for at most ``2 * sl`` steps (``sl`` = source
    length), stopping once ``EOS_token`` is emitted. ``topi.item()`` requires
    a single token per step, so batches are presumably of size 1 -- TODO
    confirm.

    Args:
        encoder: Called as ``encoder(src_ids, src_len) -> (out, hidden, cell)``.
        decoder: Called as ``decoder(input, prev_output, hiddens, cs, en_out,
            src_len)``; returns ``(out_vocab, prev_output, hiddens, cs,
            attention_score)``.
        val_dataloader: Yields batches with source ids in ``data[0]``, target
            ids in ``data[1]``, source lengths in ``data[2]`` and the
            reference in ``data[-1]``.
        lang_en: Target vocabulary.
        lang_vi: Source vocabulary, used to rebuild the source sentence.
        m_type: ``'attention'`` collects the per-step attention scores.
        verbose: Print each reference/prediction pair.
        replace_unk: Build the reference from ``data[1][0]`` instead of using
            ``data[-1]``.

    Returns:
        ``(bleu, attention_scores_for_all_val, pred_corpus, src_corpus)``.
    """
    encoder.eval()
    decoder.eval()
    pred_corpus = []
    true_corpus = []
    src_corpus = []
    bl = BLEU_SCORE()
    attention_scores_for_all_val = []
    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for data in val_dataloader:
            encoder_i = data[0].to(device)
            src_len = data[2].to(device)
            bs, sl = encoder_i.size()[:2]
            en_out, prev_hiddens, prev_cs = encoder(encoder_i, src_len)
            decoder_input = torch.tensor([[SOS_token]] * bs).to(device)
            prev_output = torch.zeros((bs, en_out.size(-1))).to(device)
            d_out = []
            attention_scores = []
            for _ in range(sl * 2):
                out_vocab, prev_output, prev_hiddens, prev_cs, attention_score = decoder(
                    decoder_input, prev_output, prev_hiddens, prev_cs, en_out, src_len)
                _, topi = out_vocab.topk(1)
                token = topi.item()
                d_out.append(token)
                # The greedy choice becomes the next decoder input.
                decoder_input = topi.squeeze().detach().view(-1, 1)
                if m_type == 'attention':
                    attention_scores.append(attention_score.unsqueeze(-1))
                if token == EOS_token:
                    break
            if replace_unk:
                true_corpus.append(convert_id_list_2_sent(data[1][0], lang_en))
            else:
                true_corpus.append(data[-1])
            src_corpus.append(convert_id_list_2_sent(data[0][0], lang_vi))
            pred_sent = convert_id_list_2_sent(d_out, lang_en)
            pred_corpus.append(pred_sent)
            if m_type == 'attention':
                # Stack per-step maps along a new trailing "decoder step" axis.
                attention_scores_for_all_val.append(torch.cat(attention_scores, dim=-1))
            if verbose:
                print("True Sentence:", data[-1])
                print("Pred Sentence:", pred_sent)
                print('-*' * 50)
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return score, attention_scores_for_all_val, pred_corpus, src_corpus
示例#4
0
def validation(encoder, decoder, dataloader, loss_fun, lang_en):
    """Compute the validation loss and corpus BLEU score.

    Args:
        encoder, decoder: Models passed to ``encode_decode``.
        dataloader: Yields batches whose ``data[0]`` is the source id tensor
            and ``data[1]`` the target id tensor.
        loss_fun: Called as ``loss_fun(out, targets)``. ``out`` is used
            unreshaped and argmaxed over dim 1, so presumably the class axis is
            dim 1 (CrossEntropyLoss-style ``(N, C, L)``) -- TODO confirm.
        lang_en: Target vocabulary passed to ``convert_idx_2_sent``.

    Returns:
        ``(mean_loss, bleu)``. ``mean_loss`` is the batch-size-weighted mean of
        ``loss_fun`` over all batches. ``bleu`` is the first element of
        ``BLEU_SCORE().corpus_bleu(...)``.
    """
    encoder.train(False)
    decoder.train(False)
    pred_corpus = []
    true_corpus = []
    running_loss = 0
    running_total = 0
    bl = BLEU_SCORE()
    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for data in dataloader:
            encoder_i = data[0].cuda()
            decoder_i = data[1].cuda()
            bs = encoder_i.size(0)
            out, _ = encode_decode(encoder, decoder, encoder_i, decoder_i)
            loss = loss_fun(out.float(), decoder_i.long())
            running_loss += loss.item() * bs
            running_total += bs
            pred = torch.max(out, dim=1)[1]
            for t, p in zip(data[1], pred):
                true_corpus.append(convert_idx_2_sent(t, lang_en))
                pred_corpus.append(convert_idx_2_sent(p, lang_en))
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return running_loss / running_total, score
示例#5
0
def validation_beam_search(encoder,
                           decoder,
                           val_dataloader,
                           lang_en,
                           beam_size,
                           verbose=False,
                           keep_unk=False,
                           return_attention=False):
    """Decode the validation set with beam search and score it with corpus BLEU.

    Only batch row 0 is read (``beam_score[0]``, ``topi[0]``), so batches are
    presumably of size 1 -- TODO confirm with the dataloader. Beam scores are
    the running sum of the selected decoder outputs (``topv``), presumably
    log-probabilities; no length normalisation is applied. Decoding runs for
    at most ``sl + 20`` steps, or until every beam has emitted ``EOS_token``.

    Args:
        encoder: Called as ``encoder(src_ids, src_len) -> (out, hidden, cell)``.
        decoder: Called as ``decoder(input, prev_output, hiddens, cs, en_out,
            src_len)``; returns ``(out_vocab, prev_output, hiddens, cs,
            attention_score)``. Must expose an ``att_layer`` attribute.
        val_dataloader: Yields batches with source ids in ``data[0]``, target
            ids in ``data[1]``, source lengths in ``data[2]`` and the
            reference in ``data[-1]``.
        lang_en: Target vocabulary passed to ``convert_id_list_2_sent``.
        beam_size: Number of hypotheses kept per step.
        verbose: Print each reference/prediction pair.
        keep_unk: Build the reference from ``data[1][0]`` instead of using
            ``data[-1]``.
        return_attention: When ``decoder.att_layer`` is not None, also return
            references, predictions and the per-step attention maps of the
            selected hypothesis.

    Returns:
        ``score`` (first element of ``BLEU_SCORE().corpus_bleu``). With
        attention enabled, ``(score, true_corpus, pred_corpus,
        attention_scores_for_all_val)``.
    """
    encoder.eval()
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    pred_corpus = []
    true_corpus = []
    bl = BLEU_SCORE()
    track_attention = decoder.att_layer is not None and return_attention
    attention_scores_for_all_val = []

    # Pure inference: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for data in val_dataloader:
            encoder_i = data[0].to(device)
            src_len = data[2].to(device)
            bs, sl = encoder_i.size()[:2]
            en_out, en_hid, en_c = encoder(encoder_i, src_len)
            prev_hiddens = en_hid
            prev_cs = en_c
            decoder_input = torch.tensor([[SOS_token]] * bs).to(device)
            prev_output = torch.zeros((bs, en_out.size(-1))).to(device)
            list_decoder_input = [None] * beam_size
            beam_stop_flags = [False] * beam_size
            beam_score = torch.zeros((bs, beam_size)).to(device)
            list_d_outs = [[] for _ in range(beam_size)]
            # Per-beam list of attention maps, one trailing-axis slice per step.
            attention_scores = [[] for _ in range(beam_size)]

            for i in range(sl + 20):
                if i == 0:
                    # Seed all beams from the top-k tokens of the first step.
                    out_vocab, prev_output, prev_hiddens, prev_cs, attention_score = decoder(
                        decoder_input, prev_output, prev_hiddens, prev_cs, en_out, src_len)
                    bss, vocab_size = out_vocab.size()
                    topv, topi = out_vocab.topk(beam_size)
                    list_prev_output = [prev_output] * beam_size
                    list_prev_hiddens = [prev_hiddens] * beam_size
                    list_prev_cs = [prev_cs] * beam_size
                    for b in range(beam_size):
                        beam_score[0][b] = topv[0][b].item()
                        list_decoder_input[b] = topi[0][b].squeeze().detach().view(-1, 1)
                        list_d_outs[b].append(topi[0][b].item())
                        if track_attention:
                            attention_scores[b].append(attention_score.unsqueeze(-1))
                        if topi[0][b].item() == EOS_token:
                            beam_stop_flags[b] = True
                else:
                    beam_out_vocab = [None] * beam_size
                    temp_out = [None] * beam_size
                    temp_hid = [None] * beam_size
                    temp_c = [None] * beam_size
                    temp_att = [None] * beam_size
                    # Snapshots so every slot b is rebuilt from the pre-step state.
                    prev_d_outs = copy.deepcopy(list_d_outs)
                    prev_att = [list(a) for a in attention_scores]
                    for b in range(beam_size):
                        if beam_stop_flags[b]:
                            # Finished beams get -inf so top-k never extends them.
                            beam_out_vocab[b] = torch.zeros(bss, vocab_size).fill_(
                                float('-inf')).to(device)
                        else:
                            (beam_out_vocab[b], temp_out[b], temp_hid[b],
                             temp_c[b], temp_att[b]) = decoder(
                                 list_decoder_input[b], list_prev_output[b],
                                 list_prev_hiddens[b], list_prev_cs[b],
                                 en_out, src_len)
                            beam_out_vocab[b] = beam_out_vocab[b] + beam_score[0][b]
                    beam_out_vocab = torch.cat(beam_out_vocab, dim=1)

                    topv, topi = beam_out_vocab.topk(beam_size)
                    # Flattened index -> (source beam, token id).
                    id_for_hid = topi // vocab_size
                    topi_idx = topi % vocab_size
                    for b in range(beam_size):
                        if beam_stop_flags[b]:
                            continue
                        src = int(id_for_hid[0][b])
                        token = topi_idx[0][b].item()
                        beam_score[0][b] = topv[0][b].item()
                        list_decoder_input[b] = topi_idx[0][b].squeeze().detach().view(-1, 1)
                        list_d_outs[b] = copy.deepcopy(prev_d_outs[src])
                        list_d_outs[b].append(token)
                        if track_attention:
                            attention_scores[b] = prev_att[src] + [temp_att[src].unsqueeze(-1)]
                        if token == EOS_token:
                            beam_stop_flags[b] = True
                        else:
                            list_prev_output[b] = temp_out[src]
                            list_prev_hiddens[b] = temp_hid[src]
                            list_prev_cs[b] = temp_c[src]
                    if all(beam_stop_flags):
                        break

            # torch.argmax works on any device, whereas np.argmax fails on CUDA tensors.
            id_max_score = int(torch.argmax(beam_score[0]).item())
            d_out = list_d_outs[id_max_score]

            if keep_unk:
                true_corpus.append(convert_id_list_2_sent(data[1][0], lang_en))
            else:
                true_corpus.append(data[-1])

            pred_sent = convert_id_list_2_sent(d_out, lang_en)
            pred_corpus.append(pred_sent)

            if track_attention:
                # Concatenate the selected beam's per-step maps; attention_scores
                # itself is a list of lists and cannot be passed to torch.cat.
                attention_scores_for_all_val.append(
                    torch.cat(attention_scores[id_max_score], dim=-1))

            if verbose:
                print("True Sentence:", data[-1])
                print("Pred Sentence:", pred_sent)
                print('-*' * 50)
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]

    if track_attention:
        return score, true_corpus, pred_corpus, attention_scores_for_all_val

    return score