def _ids_to_words(word_ids, id2word, vocab_size, source_words, oovs, batch_j,
                  model_type):
    """Map one decoded id sequence to tokens, stopping at (and dropping) '<eos>'.

    Ids below ``vocab_size`` are looked up in ``id2word``. Larger ids are copy
    pointers (``id - vocab_size``): into ``source_words`` for the 'NQG' model,
    or into the per-sample OOV list ``oovs[batch_j]`` for the 'PG' model.

    Raises:
        IndexError: If a PG copy pointer falls outside the sample's OOV list.
        ValueError: If a copy id is produced by any other model type.
    """
    words = []
    for word_id in word_ids:
        if word_id < vocab_size:
            word = id2word[word_id]
            if word == '<eos>':
                break
        else:
            pointer_idx = word_id - vocab_size
            if model_type == 'NQG':
                word = source_words[pointer_idx]
            elif model_type == 'PG':
                try:
                    word = oovs[batch_j][pointer_idx]
                except IndexError as err:
                    raise IndexError(
                        f'Copy pointer {pointer_idx} is out of range for the '
                        f'OOV list of sample {batch_j}') from err
            else:
                # Without this, the previous token would silently be reused.
                raise ValueError(
                    f'Copy id {word_id} produced by non-copy model '
                    f'{model_type!r}')
        words.append(word)
    return words


def _print_samples(config, source_WORD, target_WORD, answer_WORD, focus_WORD,
                   generated_focus_mask, focus_p, hypotheses, batch_offset):
    """Print up to 10 samples of the current batch next to their generations.

    ``generated_focus_mask`` / ``focus_p`` are expected as ``[B, n_mixture, L]``
    when ``config.use_focus`` is set; ``focus_p`` may be None (oracle focus).
    ``batch_offset`` is the index in each ``hypotheses[h]`` where this batch's
    generations start.
    """
    n_hypotheses = (config.n_mixture if config.n_mixture > 1
                    else config.decode_k)
    for i in range(min(10, len(source_WORD))):
        s = source_WORD[i]  # [L]
        print(f'[{i}]')
        print(f"Source Sequence: {' '.join(source_WORD[i])}")
        if config.task == 'QG':
            print(f"Answer: {' '.join(answer_WORD[i])}")
        if config.use_focus:
            print(f"Oracle Focus: {' '.join(focus_WORD[i])}")
        if config.task == 'QG':
            print(f"Target Question: {' '.join(target_WORD[i])}")
        elif config.task == 'SM':
            print(f"Target Summary: {' '.join(target_WORD[i])}")

        for h in range(n_hypotheses):
            if config.use_focus:
                # Each mixture has its own focus; beams (decode_k > 1) all
                # share the single focus at index 0.
                focus_idx = h if config.n_mixture > 1 else 0
                g_m = generated_focus_mask[i][focus_idx].tolist()  # [L]
                print(f'(focus {h})')
                if focus_p is not None:
                    f_p = focus_p[i][focus_idx].tolist()  # [L]
                    print(
                        f"Focus Prob: {' '.join([f'({w}/{p:.2f})' for (w, p) in zip(s, f_p)])}"
                    )
                print(
                    f"Generated Focus: {' '.join([w for w, m in zip(s, g_m) if m == 1])}"
                )
            generated = ' '.join(hypotheses[h][batch_offset + i])
            if config.task == 'QG':
                print(f"Generated Question: {generated}\n")
            elif config.task == 'SM':
                print(f"Generated Summary: {generated}\n")


def _score_multiple_hypotheses(config, references, hypotheses,
                               best_hypothesis):
    """Score the best hypothesis plus oracle / self / average metrics.

    Every returned value is already multiplied by 100.
    """
    n_process = min(4, n_cpus)

    if config.task == 'QG':
        nested_references = [[r] for r in references]
        bleu_4 = bleu.corpus_bleu(nested_references, best_hypothesis) * 100
        print(f"BLEU-4: {bleu_4:.3f}")
        oracle_bleu_4 = bleu.oracle_bleu(
            hypotheses, references, n_process=n_process) * 100
        print(f"Oracle BLEU-4: {oracle_bleu_4:.3f}")
        self_bleu = bleu.self_bleu(hypotheses, n_process=n_process) * 100
        print(f"Self BLEU-4: {self_bleu:.3f}")
        avg_bleu = bleu.avg_bleu(hypotheses, references) * 100
        print(f"Average BLEU-4: {avg_bleu:.3f}")
        return {
            'BLEU-4': bleu_4,
            'Oracle_BLEU-4': oracle_bleu_4,
            'Self_BLEU-4': self_bleu,
            'Average_BLEU-4': avg_bleu,
        }

    # config.task == 'SM'
    summaries = [split_sentences(words) for words in best_hypothesis]
    rouge_eval_start = time.time()
    rouge_dict = rouge.corpus_rouge(summaries, references, n_process=n_process)
    print(f'ROUGE calc time: {time.time() - rouge_eval_start:.3f}s')
    for metric_name, value in rouge_dict.items():
        print(f"{metric_name}: {value * 100:.3f}")

    hypotheses_ = [[split_sentences(words) for words in hypothesis]
                   for hypothesis in hypotheses]

    oracle_rouge_eval_start = time.time()
    oracle_rouge = rouge.oracle_rouge(hypotheses_, references,
                                      n_process=n_process)
    print(
        f'Oracle ROUGE calc time: {time.time() - oracle_rouge_eval_start:.3f}s'
    )
    for metric_name, value in oracle_rouge.items():
        print(f"Oracle {metric_name}: {value * 100:.3f}")

    self_rouge_eval_start = time.time()
    self_rouge = rouge.self_rouge(hypotheses_, n_process=n_process)
    print(
        f'Self ROUGE calc time: {time.time() - self_rouge_eval_start:.3f}s'
    )
    for metric_name, value in self_rouge.items():
        print(f"Self {metric_name}: {value * 100:.3f}")

    avg_rouge_eval_start = time.time()
    avg_rouge = rouge.avg_rouge(hypotheses_, references, n_process=n_process)
    print(
        f'Average ROUGE calc time: {time.time() - avg_rouge_eval_start:.3f}s'
    )
    for metric_name, value in avg_rouge.items():
        print(f"Average {metric_name}: {value * 100:.3f}")

    metric_result = {}
    for prefix, scores in (('', rouge_dict), ('Oracle_', oracle_rouge),
                           ('Self_', self_rouge), ('Average_', avg_rouge)):
        for name in ('ROUGE-1', 'ROUGE-2', 'ROUGE-L'):
            metric_result[prefix + name] = scores[name] * 100
    return metric_result


def _score_single_hypothesis(config, references, hypotheses):
    """Score ``hypotheses[0]`` (BLEU-4 for QG, ROUGE for SM), scaled by 100."""
    flat_hypothesis = hypotheses[0]
    if config.task == 'QG':
        nested_references = [[r] for r in references]
        metric_result = {
            'BLEU-4': bleu.corpus_bleu(nested_references, flat_hypothesis)
        }
    else:  # 'SM'
        summaries = [split_sentences(words) for words in flat_hypothesis]
        metric_result = rouge.corpus_rouge(summaries, references,
                                           n_process=min(4, n_cpus))
    metric_result = {k: v * 100 for k, v in metric_result.items()}
    for metric_name, value in metric_result.items():
        print(f"{metric_name}: {value:.3f}")
    return metric_result


def evaluate(loader, model, epoch, config, test=False):
    """Decode every batch in ``loader`` and score the generations.

    Runs ``model`` in eval mode under ``torch.no_grad()``. For each batch, if
    ``config.use_focus`` is set, it takes a focus mask either from the oracle
    ``focus_mask`` (``config.eval_focus_oracle``) or by thresholding
    ``model.selector``'s output at ``config.threshold``. It then decodes with
    ``model.seq2seq`` and maps ids back to tokens, resolving copy pointers.
    Finally it scores the hypotheses: BLEU-4 for ``config.task == 'QG'`` and
    ROUGE for ``'SM'``. When there are several hypotheses per input
    (``n_mixture > 1`` or ``decode_k > 1``) and the oracle focus is not used,
    it also reports oracle, self and average scores.

    Args:
        loader: Iterable of batches supporting ``len()``; its
            ``dataset.df`` must provide the references (``target_WORD`` for
            QG, ``target_multiref`` for SM).
        model: Object exposing ``selector``, ``seq2seq``, ``word2id`` and
            ``id2word``.
        epoch: Epoch number, used only in progress logs.
        config: Run configuration (task, n_mixture, decode_k, beam_k,
            use_focus, eval_focus_oracle, decoding, ...).
        test: When False, samples from the first batch are printed.

    Returns:
        Tuple ``(metric_result, hypotheses, best_hypothesis, hyp_focus,
        hyp_attention)``:
        - ``metric_result`` maps metric names to scores multiplied by 100.
        - ``hypotheses[h]`` holds the token list generated in hypothesis
          slot ``h`` for each input.
        - ``best_hypothesis`` holds the highest-scoring slot per input, or
          None when only one hypothesis is decoded.
        - ``hyp_focus`` and ``hyp_attention`` hold per-slot selector outputs
          and attention arrays, filled only when ``config.use_focus`` is set.

    Raises:
        ValueError: If ``config.task`` is neither 'QG' nor 'SM'.
    """
    max_dec_len = {'QG': 30, 'SM': 120}.get(config.task)
    if max_dec_len is None:
        raise ValueError(
            f"Unsupported task {config.task!r}; expected 'QG' or 'SM'")

    start = time.time()
    print('Evaluation start!')
    model.eval()

    if config.task == 'QG':
        references = loader.dataset.df.target_WORD.tolist()
    else:  # 'SM'
        references = loader.dataset.df.target_multiref.tolist()

    n_slots = max(config.n_mixture, config.decode_k)
    hypotheses = [[] for _ in range(n_slots)]
    hyp_focus = [[] for _ in range(n_slots)]
    hyp_attention = [[] for _ in range(n_slots)]

    # Mixtures and multi-beam output are mutually exclusive.
    if config.n_mixture > 1:
        assert config.decode_k == 1
    elif config.decode_k > 1:
        assert config.n_mixture == 1
    use_multiple_hypotheses = config.n_mixture > 1 or config.decode_k > 1
    best_hypothesis = [] if use_multiple_hypotheses else None

    id2word = model.id2word
    vocab_size = len(model.word2id)

    n_iter = len(loader)
    temp_time_start = time.time()

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            batch = [b.to(device) if isinstance(b, torch.Tensor) else b
                     for b in batch]
            if config.task == 'QG':
                (source_WORD_encoding, source_len,
                 target_WORD_encoding, target_len,
                 source_WORD, target_WORD,
                 answer_position_BIO_encoding, answer_WORD,
                 ner, ner_encoding,
                 pos, pos_encoding,
                 case, case_encoding,
                 focus_WORD, focus_mask,
                 focus_input, answer_WORD_encoding,
                 source_WORD_encoding_extended, oovs) = batch
            else:  # 'SM'
                (source_WORD_encoding, source_len,
                 target_WORD_encoding, target_len,
                 source_WORD, target_WORD,
                 focus_WORD, focus_mask,
                 focus_input,
                 source_WORD_encoding_extended, oovs) = batch
                answer_position_BIO_encoding = answer_WORD = None
                ner_encoding = pos_encoding = case_encoding = None
                answer_WORD_encoding = None

            B, L = source_WORD_encoding.size()

            # Selector output; stays None for oracle focus or no focus, so
            # the visualization / printing code below never reads an unset
            # name.
            focus_p = None
            if config.use_focus:
                if config.eval_focus_oracle:
                    generated_focus_mask = focus_mask
                else:
                    # [B * n_mixture, L]
                    focus_p = model.selector(
                        source_WORD_encoding,
                        answer_position_BIO_encoding,
                        ner_encoding,
                        pos_encoding,
                        case_encoding,
                        train=False)
                    generated_focus_mask = (focus_p > config.threshold).long()

                    # The selector emits one focus per mixture, so repeat the
                    # encoder inputs to line up with [B * n_mixture, L].
                    if config.n_mixture > 1:
                        source_WORD_encoding = repeat(
                            source_WORD_encoding, config.n_mixture)
                        if config.feature_rich:
                            answer_position_BIO_encoding = repeat(
                                answer_position_BIO_encoding, config.n_mixture)
                            ner_encoding = repeat(ner_encoding,
                                                  config.n_mixture)
                            pos_encoding = repeat(pos_encoding,
                                                  config.n_mixture)
                            case_encoding = repeat(case_encoding,
                                                   config.n_mixture)
                        if config.model == 'PG':
                            source_WORD_encoding_extended = repeat(
                                source_WORD_encoding_extended,
                                config.n_mixture)
                            assert source_WORD_encoding.size(0) \
                                == source_WORD_encoding_extended.size(0)
                input_mask = generated_focus_mask
            else:
                input_mask = None
                generated_focus_mask = focus_mask

            # [B * n_mixture, K, max_len]
            prediction, score = model.seq2seq(
                source_WORD_encoding,
                answer_WORD_encoding=answer_WORD_encoding,
                answer_position_BIO_encoding=answer_position_BIO_encoding,
                ner_encoding=ner_encoding,
                pos_encoding=pos_encoding,
                case_encoding=case_encoding,
                focus_mask=input_mask,
                target_WORD_encoding=None,
                source_WORD_encoding_extended=source_WORD_encoding_extended,
                train=False,
                decoding_type=config.decoding,
                beam_k=config.beam_k,
                max_dec_len=max_dec_len,
                temperature=config.temperature,
                diversity_lambda=config.diversity_lambda)

            prediction = prediction.view(B, config.n_mixture, config.beam_k,
                                         -1)
            # [B, n_mixture, decode_k, dec_len]
            prediction = prediction[:, :, 0:config.decode_k, :].tolist()

            if use_multiple_hypotheses:
                score = score.view(B, config.n_mixture, config.beam_k)
                # reshape: the slice may be non-contiguous.
                score = score[:, :, :config.decode_k].reshape(B, -1)
                # [B]
                best_hyp_idx = score.argmax(dim=1).tolist()

            # Word IDs => Words
            batch_offset = len(hypotheses[0])
            for batch_j, (predicted_word_ids, source_words, _target_words) \
                    in enumerate(zip(prediction, source_WORD, target_WORD)):
                if config.n_mixture > 1:
                    candidates = [predicted_word_ids[n][0]
                                  for n in range(config.n_mixture)]
                else:
                    candidates = [predicted_word_ids[0][k]
                                  for k in range(config.decode_k)]
                for h, word_ids in enumerate(candidates):
                    predicted_words = _ids_to_words(
                        word_ids, id2word, vocab_size, source_words, oovs,
                        batch_j, config.model)
                    hypotheses[h].append(predicted_words)
                    if use_multiple_hypotheses and best_hyp_idx[batch_j] == h:
                        best_hypothesis.append(predicted_words)

            # For visualization
            if config.use_focus:
                generated_focus_mask = generated_focus_mask.view(
                    B, config.n_mixture, L)
                if focus_p is not None:
                    # [B * n_mixture, L] => [B, n_mixture, L]
                    focus_p = focus_p.view(B, config.n_mixture, L)
                    # n_mixture * [B, 1, L] => B x [L] per mixture
                    for n, focus_n in enumerate(focus_p.split(1, dim=1)):
                        hyp_focus[n].extend(focus_n.squeeze(1).tolist())

                # target_L x [B * n_mixture, L]
                # => [B, n_mixture, L, target_L]
                attention_list = torch.stack(
                    model.seq2seq.decoder.attention_list,
                    dim=2).view(B, config.n_mixture, L, -1)
                # n_mixture * [B, 1, L, target_L] => B x [L, target_L]
                for n, attention in enumerate(attention_list.split(1, dim=1)):
                    hyp_attention[n].extend(
                        np.array(at) for at in attention.squeeze(1).tolist())

            if (not test) and batch_idx == 0:
                _print_samples(config, source_WORD, target_WORD, answer_WORD,
                               focus_WORD, generated_focus_mask, focus_p,
                               hypotheses, batch_offset)

            if batch_idx % 100 == 0 or (batch_idx + 1) == n_iter:
                log_str = f'Evaluation | Epoch [{epoch}/{config.epochs}]'
                log_str += f' | Iteration [{batch_idx}/{n_iter}]'
                time_taken = time.time() - temp_time_start
                log_str += f' | Time taken: : {time_taken:.2f}'
                print(log_str)
                temp_time_start = time.time()

    time_taken = time.time() - start
    print(f"Generation Done! It took {time_taken:.2f}s")

    if test:
        print('Test Set Evaluation Result')

    score_calc_start = time.time()
    if not config.eval_focus_oracle and use_multiple_hypotheses:
        metric_result = _score_multiple_hypotheses(
            config, references, hypotheses, best_hypothesis)
    else:
        metric_result = _score_single_hypothesis(config, references,
                                                 hypotheses)
    score_calc_time_taken = time.time() - score_calc_start
    print(f'Score calculation Done! It took {score_calc_time_taken:.2f}s')

    return metric_result, hypotheses, best_hypothesis, hyp_focus, hyp_attention
# [B * n_mixture, L] focus_logit = model.selector( source_WORD_encoding, answer_position_BIO_encoding= answer_position_BIO_encoding, ner_encoding=ner_encoding, pos_encoding=pos_encoding, case_encoding=case_encoding, mixture_id=None, focus_input=focus_input, train=True) B, L = source_WORD_encoding.size() # [B * n_mixture, L] repeated_target = repeat(focus_mask.float(), config.n_mixture) # [B * n_mixture, L] focus_loss = F.binary_cross_entropy_with_logits( focus_logit, repeated_target, reduction='none').view(B, config.n_mixture, L) pad_mask = (source_WORD_encoding == PAD_ID).view( B, 1, L) mixture_id = focus_loss.masked_fill( pad_mask, 0).sum(dim=2).argmin(dim=1) # 2) Train with the selected SELECTOR expert (M-Step) model.selector.train() # [B, L] focus_logit = model.selector(source_WORD_encoding,