def validation_new(encoder, decoder, val_dataloader, lang_en, m_type):
    """Greedy-decode the validation set and return its corpus BLEU score.

    Args:
        encoder: seq2seq encoder module (must provide ``initHidden``).
        decoder: decoder module; when ``m_type == "attention"`` it is called
            with the encoder outputs as a third argument.
        val_dataloader: yields batches where ``data[0]`` is the source id
            tensor and ``data[-1]`` the reference sentence.
        lang_en: target-language vocabulary used to detokenize predictions.
        m_type: model type selector; ``"attention"`` or anything else.

    Returns:
        Corpus BLEU score (float).

    NOTE(review): ``topi.item()`` requires an effective batch size of 1 per
    dataloader item — confirm against the dataloader.
    """
    # .eval() is equivalent to train(False); matches the sibling functions.
    encoder.eval()
    decoder.eval()
    pred_corpus = []
    true_corpus = []
    bl = BLEU_SCORE()
    for data in val_dataloader:
        encoder_i = data[0].cuda()
        bs, sl = encoder_i.size()[:2]
        en_h = encoder.initHidden(bs)
        en_out, en_hid = encoder(encoder_i, en_h)
        decoder_hidden = en_hid
        decoder_input = torch.tensor([[SOS_token]] * bs).cuda()
        d_out = []
        # Decode at most source-length + 20 tokens, stopping early at EOS.
        for _ in range(sl + 20):
            if m_type == "attention":
                decoder_output, decoder_hidden = decoder(
                    decoder_input, decoder_hidden, en_out)
            else:
                decoder_output, decoder_hidden = decoder(
                    decoder_input, decoder_hidden)
            _, topi = decoder_output.topk(1)
            # Feed the greedy choice back in; detach so no autograd graph
            # accumulates across decoding steps.
            decoder_input = topi.squeeze().detach().view(-1, 1)
            d_out.append(topi.item())
            if topi.item() == EOS_token:
                break
        true_corpus.append(data[-1])
        pred_corpus.append(convert_id_list_2_sent(d_out, lang_en))
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return score
def validation(encoder, decoder, dataloader, loss_fun, lang_en, max_len, m_type):
    """Evaluate via ``encode_decode`` and return (average loss, corpus BLEU).

    Args:
        encoder, decoder: seq2seq modules forwarded to ``encode_decode``.
        dataloader: yields batches where ``data[0]`` is the source id tensor
            and ``data[1]`` the target id tensor.
        loss_fun: token-level loss over (batch*seq, vocab) logits.
        lang_en: target vocabulary (provides ``n_words`` and detokenization).
        max_len: maximum decode length, forwarded to ``encode_decode``.
        m_type: model type selector, forwarded to ``encode_decode``.

    Returns:
        Tuple of (mean per-example loss, corpus BLEU score).
    """
    encoder.train(False)
    decoder.train(False)
    pred_corpus = []
    true_corpus = []
    running_loss = 0
    running_total = 0
    bl = BLEU_SCORE()
    for data in dataloader:
        encoder_i = data[0].cuda()
        decoder_i = data[1].cuda()
        bs = encoder_i.size(0)
        # rand_num=0 presumably disables teacher forcing during eval —
        # confirm against encode_decode's signature.
        out, _ = encode_decode(encoder, decoder, encoder_i, decoder_i,
                               max_len, m_type, rand_num=0)
        # Flatten to (batch*seq, vocab) vs (batch*seq,) for the loss.
        loss = loss_fun(out.view(-1, lang_en.n_words).float(),
                        decoder_i.view(-1).long())
        running_loss += loss.item() * bs
        running_total += bs
        # Greedy argmax over the vocabulary dimension.
        pred = torch.max(out, dim=2)[1]
        for t, p in zip(decoder_i, pred):
            true_corpus.append(convert_idx_2_sent(t, lang_en))
            pred_corpus.append(convert_idx_2_sent(p, lang_en))
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return running_loss / running_total, score
def validation_new(encoder, decoder, val_dataloader, lang_en, lang_vi, m_type,
                   verbose=False, replace_unk=False):
    """Greedy-decode the validation set with an LSTM encoder/decoder.

    Args:
        encoder: returns (outputs, hidden, cell) given (ids, lengths).
        decoder: step decoder returning
            (vocab logits, output, hiddens, cells, attention score).
        val_dataloader: yields batches with source ids at ``data[0]``,
            target ids at ``data[1]``, source lengths at ``data[2]`` and the
            reference sentence at ``data[-1]``.
        lang_en: target (English) vocabulary for detokenization.
        lang_vi: source (Vietnamese) vocabulary for detokenization.
        m_type: ``'attention'`` enables attention-score collection.
        verbose: print each reference/prediction pair.
        replace_unk: rebuild the reference from token ids instead of using
            the raw reference string.

    Returns:
        (BLEU score, per-sentence attention tensors, predictions, sources).

    NOTE(review): ``topi.item()`` and the ``data[...][0]`` indexing assume an
    effective batch size of 1 — confirm against the dataloader.
    """
    encoder.eval()
    decoder.eval()
    pred_corpus = []
    true_corpus = []
    src_corpus = []
    bl = BLEU_SCORE()
    attention_scores_for_all_val = []
    for data in val_dataloader:
        encoder_i = data[0].to(device)
        src_len = data[2].to(device)
        bs, sl = encoder_i.size()[:2]
        en_out, en_hid, en_c = encoder(encoder_i, src_len)
        prev_hiddens = en_hid
        prev_cs = en_c
        decoder_input = torch.tensor([[SOS_token]] * bs).to(device)
        # Initial "previous output" context vector is all zeros.
        prev_output = torch.zeros((bs, en_out.size(-1))).to(device)
        d_out = []
        attention_scores = []
        # Cap generation at twice the (padded) source length.
        for _ in range(sl * 2):
            out_vocab, prev_output, prev_hiddens, prev_cs, attention_score = \
                decoder(decoder_input, prev_output, prev_hiddens, prev_cs,
                        en_out, src_len)
            _, topi = out_vocab.topk(1)
            d_out.append(topi.item())
            # Feed the greedy choice back; detach to drop the autograd graph.
            decoder_input = topi.squeeze().detach().view(-1, 1)
            if m_type == 'attention':
                attention_scores.append(attention_score.unsqueeze(-1))
            if topi.item() == EOS_token:
                break
        if replace_unk:
            # Re-derive the reference text from token ids.
            true_corpus.append(convert_id_list_2_sent(data[1][0], lang_en))
        else:
            true_corpus.append(data[-1])
        src_corpus.append(convert_id_list_2_sent(data[0][0], lang_vi))
        pred_sent = convert_id_list_2_sent(d_out, lang_en)
        pred_corpus.append(pred_sent)
        if m_type == 'attention':
            # Stack per-step scores along a new trailing (time) dimension.
            attention_scores = torch.cat(attention_scores, dim=-1)
            attention_scores_for_all_val.append(attention_scores)
        if verbose:
            print("True Sentence:", data[-1])
            print("Pred Sentence:", pred_sent)
            print('-*' * 50)
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return score, attention_scores_for_all_val, pred_corpus, src_corpus
def validation(encoder, decoder, dataloader, loss_fun, lang_en):
    """Evaluate via ``encode_decode`` and return (average loss, corpus BLEU).

    Args:
        encoder, decoder: seq2seq modules forwarded to ``encode_decode``.
        dataloader: yields batches with source ids at ``data[0]`` and target
            ids at ``data[1]``.
        loss_fun: loss applied to the raw ``encode_decode`` output vs targets.
        lang_en: target vocabulary for detokenization.

    Returns:
        Tuple of (mean per-example loss, corpus BLEU score).
    """
    # .eval() is equivalent to train(False); matches the sibling functions.
    encoder.eval()
    decoder.eval()
    pred_corpus = []
    true_corpus = []
    running_loss = 0
    running_total = 0
    bl = BLEU_SCORE()
    for data in dataloader:
        encoder_i = data[0].cuda()
        decoder_i = data[1].cuda()
        bs = encoder_i.size(0)
        out, _ = encode_decode(encoder, decoder, encoder_i, decoder_i)
        loss = loss_fun(out.float(), decoder_i.long())
        running_loss += loss.item() * bs
        running_total += bs
        # NOTE(review): argmax over dim=1 implies `out` is laid out as
        # (batch, vocab, seq) here, while the other `validation` uses dim=2 —
        # confirm which layout encode_decode produces in this variant.
        pred = torch.max(out, dim=1)[1]
        # Iterates the CPU-side targets (data[1]), as in the original.
        for t, p in zip(data[1], pred):
            true_corpus.append(convert_idx_2_sent(t, lang_en))
            pred_corpus.append(convert_idx_2_sent(p, lang_en))
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    return running_loss / running_total, score
def validation_beam_search(encoder, decoder, val_dataloader, lang_en, beam_size, verbose=False, keep_unk=False, return_attention=False):
    """Beam-search decode the validation set and return corpus BLEU.

    Maintains ``beam_size`` parallel hypotheses per sentence as lists of
    decoder inputs/outputs/hidden states, expands the unfinished ones each
    step, and keeps the top ``beam_size`` continuations by accumulated score.

    Args:
        encoder: returns (outputs, hidden, cell) given (ids, lengths).
        decoder: step decoder returning
            (vocab scores, output, hiddens, cells, attention score);
            must expose an ``att_layer`` attribute (may be None).
        val_dataloader: yields batches with source ids at ``data[0]``, target
            ids at ``data[1]``, lengths at ``data[2]``, reference at ``data[-1]``.
        lang_en: target vocabulary for detokenization.
        beam_size: number of hypotheses kept per step.
        verbose: print each reference/prediction pair.
        keep_unk: rebuild the reference from token ids instead of raw text.
        return_attention: also return collected attention scores (only when
            the decoder has an attention layer).

    Returns:
        BLEU score, or (score, true_corpus, pred_corpus, attention scores)
        when attention is available and ``return_attention`` is set.

    NOTE(review): indexing like ``topi[0][b]`` and ``beam_score[0][b]``
    assumes batch size 1 per dataloader item — confirm.
    """
    encoder.eval()
    decoder.eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    pred_corpus = []
    true_corpus = []
    running_loss = 0  # NOTE(review): never updated — appears to be dead state
    running_total = 0  # NOTE(review): never updated — appears to be dead state
    bl = BLEU_SCORE()
    attention_scores_for_all_val = []
    j = 0  # NOTE(review): unused
    for data in val_dataloader:
        encoder_i = data[0].to(device)
        src_len = data[2].to(device)
        bs, sl = encoder_i.size()[:2]
        en_out, en_hid, en_c = encoder(encoder_i, src_len)
        max_src_len_batch = max(src_len).item()  # NOTE(review): unused
        prev_hiddens = en_hid
        prev_cs = en_c
        decoder_input = torch.tensor([[SOS_token]] * bs).to(device)
        # Initial "previous output" context vector is all zeros.
        prev_output = torch.zeros((bs, en_out.size(-1))).to(device)
        # Per-beam bookkeeping: next input, finished flag, accumulated score,
        # generated token ids, and (optionally) attention scores.
        list_decoder_input = [None] * beam_size
        beam_stop_flags = [False] * beam_size
        beam_score = torch.zeros((bs, beam_size)).to(device)
        list_d_outs = [[] for _ in range(beam_size)]
        select_beam_size = beam_size  # NOTE(review): unused
        attention_scores = [[] for _ in range(beam_size)]
        # Decode at most source-length + 20 steps.
        for i in range(sl + 20):
            if i == 0:
                # First step: one decoder call seeds all beams from the
                # top-k of a single distribution.
                out_vocab, prev_output, prev_hiddens, prev_cs, attention_score = decoder(
                    decoder_input, prev_output, prev_hiddens, prev_cs,
                    en_out, src_len)
                bss, vocab_size = out_vocab.size()
                topv, topi = out_vocab.topk(beam_size)
                # All beams start from the same decoder state.
                list_prev_output = [prev_output] * beam_size
                list_prev_hiddens = [prev_hiddens] * beam_size
                list_prev_cs = [prev_cs] * beam_size
                for b in range(beam_size):
                    beam_score[0][b] = topv[0][b].item()
                    list_decoder_input[b] = topi[0][b].squeeze().detach().view(-1, 1)
                    list_d_outs[b].append(topi[0][b].item())
                    if decoder.att_layer is not None and return_attention:
                        attention_scores[b].append(attention_score.unsqueeze(-1))
                    if topi[0][b].item() == EOS_token:
                        beam_stop_flags[b] = True
            else:
                # Expansion step: advance each live beam one token, then
                # re-select the global top-k over all beam continuations.
                beam_out_vocab = [None] * beam_size
                temp_out = [None] * beam_size
                temp_hid = [None] * beam_size
                temp_c = [None] * beam_size
                # Snapshot token histories before beams get reassigned below.
                prev_d_outs = copy.deepcopy(list_d_outs)
                for b in range(beam_size):
                    if not beam_stop_flags[b]:
                        beam_out_vocab[b], temp_out[b], temp_hid[b], temp_c[b], attention_score = decoder(
                            list_decoder_input[b], list_prev_output[b],
                            list_prev_hiddens[b], list_prev_cs[b],
                            en_out, src_len)
                        # Accumulate the beam's running score onto each
                        # candidate continuation.
                        beam_out_vocab[b] = beam_out_vocab[b] + beam_score[0][b]
                    if beam_stop_flags[b]:
                        # Finished beams can never be extended: -inf masks
                        # them out of the top-k below.
                        beam_out_vocab[b] = torch.zeros(bss, vocab_size).fill_(float('-inf')).to(device)
                # Concatenate to (1, beam_size * vocab) so topk ranges over
                # every (beam, token) continuation at once.
                beam_out_vocab = torch.cat(beam_out_vocab, dim=1)
                topv, topi = beam_out_vocab.topk(beam_size)
                # Recover which source beam and which token each winner is.
                id_for_hid = topi // vocab_size
                topi_idx = topi % vocab_size
                for b in range(beam_size):
                    if not beam_stop_flags[b]:
                        beam_score[0][b] = topv[0][b].item()
                        list_decoder_input[b] = topi_idx[0][b].squeeze().detach().view(-1, 1)
                        # Inherit the history of the winning parent beam.
                        list_d_outs[b] = copy.deepcopy(prev_d_outs[id_for_hid[0][b]])
                        list_d_outs[b].append(topi_idx[0][b].item())
                        if topi_idx[0][b].item() == EOS_token:
                            beam_stop_flags[b] = True
                        else:
                            # Carry over the parent beam's decoder state.
                            list_prev_output[b] = temp_out[id_for_hid[0][b]]
                            list_prev_hiddens[b] = temp_hid[id_for_hid[0][b]]
                            list_prev_cs[b] = temp_c[id_for_hid[0][b]]
            if all(beam_stop_flags):
                break
        # Pick the highest-scoring finished (or truncated) beam.
        # NOTE(review): np.argmax on a torch tensor — fails if beam_score is
        # on CUDA; confirm device handling.
        id_max_score = np.argmax(beam_score)
        d_out = list_d_outs[id_max_score]
        if keep_unk:
            # Re-derive the reference text from token ids.
            true_sent = convert_id_list_2_sent(data[1][0], lang_en)
            true_corpus.append(true_sent)
        else:
            true_corpus.append(data[-1])
        pred_sent = convert_id_list_2_sent(d_out, lang_en)
        pred_corpus.append(pred_sent)
        if decoder.att_layer is not None and return_attention:
            # NOTE(review): attention_scores is a list of per-beam LISTS here,
            # so torch.cat on it looks like it would raise — confirm whether
            # this path is ever exercised.
            attention_scores = torch.cat(attention_scores, dim=-1)
            attention_scores_for_all_val.append(attention_scores)
        if verbose:
            print("True Sentence:", data[-1])
            print("Pred Sentence:", pred_sent)
            print('-*' * 50)
    score = bl.corpus_bleu(pred_corpus, [true_corpus], lowercase=True)[0]
    if decoder.att_layer is not None and return_attention:
        return score, true_corpus, pred_corpus, attention_scores_for_all_val
    return score