import codecs

from numpy import mean  # assumption: mean comes from numpy; statistics.mean would work equally well
# SARIsent is the sentence-level scorer from the standard SARI.py script
# (Xu et al., 2016); it is assumed to be defined or imported alongside this file.


def SARI_file(source, preds, refs, preprocess):
    # Score line-aligned files: one source, one prediction, and one reference
    # line per example; multiple references are tab-separated on a line.
    files = [codecs.open(fis, "r", 'utf-8') for fis in [source, preds, refs]]
    scores = []
    for src, pred, ref in zip(*files):
        references = [preprocess(r) for r in ref.split('\t')]
        scores.append(SARIsent(preprocess(src), preprocess(pred), references))
    for fis in files:
        fis.close()
    return mean(scores)
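# --- Usage sketch (illustrative, not part of the original script) ---
# A minimal example of driving SARI_file. The file paths and the
# lowercase_preprocess helper below are hypothetical placeholders.
def lowercase_preprocess(sent):
    # hypothetical preprocessor: trim whitespace and lowercase before scoring
    return sent.strip().lower()

# score = SARI_file('data/test.src', 'out/test.pred', 'data/test.refs',
#                   lowercase_preprocess)
# print('corpus SARI:', score)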
def compute_sari(test_data, predictions):
    # SARIsent returns a value in [0, 1]; scale to 0-100 for reporting.
    source_sentences = [get_processed_comment_str(ex.old_comment_subtokens)
                        for ex in test_data]
    target_sentences = [[get_processed_comment_str(ex.new_comment_subtokens)]
                        for ex in test_data]  # one reference per example
    predicted_sentences = [' '.join(p) for p in predictions]
    inp = zip(source_sentences, target_sentences, predicted_sentences)
    scores = []
    for source, target, predicted in inp:
        scores.append(SARIsent(source, predicted, target))
    return 100 * sum(scores) / float(len(scores))
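# --- Usage sketch (illustrative, not part of the original script) ---
# compute_sari expects examples carrying subtokenized old/new comments plus
# one predicted subtoken list per example. The Example namedtuple and the
# stand-in get_processed_comment_str below are hypothetical; the repo's real
# helper is defined elsewhere and should be used instead.
from collections import namedtuple

Example = namedtuple('Example',
                     ['old_comment_subtokens', 'new_comment_subtokens'])

def get_processed_comment_str(subtokens):
    # stand-in for the repo's real helper: join subtokens into one string
    return ' '.join(subtokens)

demo_data = [Example(['returns', 'the', 'sum'], ['returns', 'the', 'product'])]
demo_preds = [['returns', 'the', 'product']]
# compute_sari(demo_data, demo_preds) -> a score on the 0-100 scale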
def evaluate(self, dataset, vocab, model, args, max_edit_steps=50):
    """
    Evaluate a model on a given dataset and return its performance during training.

    Args:
        dataset: an object of data.Dataset()
        model (editNTS model): model to evaluate
        vocab: an object containing data.Vocab()
        args: args from the main methods

    Returns:
        loss (float): loss of the given model on the given dataset, evaluated
            with teacher forcing (nan while that code path is commented out below)
        bleu (float): mean sentence-level BLEU over the dataset
        sari (float): mean sentence-level SARI, computed with the python SARI script
        sys_out (list): the decoded output sentences
    """
    # Relies on numpy (np) and on helpers defined elsewhere in the repo:
    # data.prepare_batch, data.id2edits, sort_by_lens, edit2sent,
    # cal_bleu_score, and SARIsent.
    print_loss, print_loss_tf = [], []
    bleu_list = []
    ter = 0.
    sari_list = []
    sys_out = []
    print('Doing tokenized evaluation')
    for i, batch_df in dataset.batch_generator(batch_size=1, shuffle=False):
        model.eval()
        prepared_batch, syn_tokens_list = data.prepare_batch(
            batch_df, vocab, args.max_seq_len)  # comp, scpn, simp
        org_ids = prepared_batch[0]
        org_lens = org_ids.ne(0).sum(1)
        org = sort_by_lens(
            org_ids, org_lens
        )  # inp=[inp_sorted, inp_lengths_sorted, inp_sort_order]
        org_pos_ids = prepared_batch[1]
        org_pos_lens = org_pos_ids.ne(0).sum(1)
        org_pos = sort_by_lens(
            org_pos_ids, org_pos_lens
        )  # inp=[inp_sorted, inp_lengths_sorted, inp_sort_order]
        out = prepared_batch[2][:, :]
        tar = prepared_batch[2][:, 1:]
        simp_ids = prepared_batch[3]
        best_seq_list = model.beamsearch(org, out, simp_ids, org_ids,
                                         org_pos, 5)  # beam width 5
        # output_without_teacher_forcing = model(org, out, org_ids, org_pos, simp_ids, 0.0)  # can't compute loss for this one, can only do teacher forcing
        # output_teacher_forcing = model(org, out, org_ids, org_pos, simp_ids, 1.0)
        # if True:  # the loss on validation is computed based on teacher forcing
        #     ################## calculate loss
        #     tar_lens = tar.ne(0).sum(1).float()
        #     tar_flat = tar.contiguous().view(-1)
        #     def compute_loss(output, tar_flat):  # computes the loss from model outputs and the flattened target
        #         loss = self.loss(output.contiguous().view(-1, vocab.count), tar_flat).contiguous()
        #         loss[tar_flat == 1] = 0  # remove loss for UNK
        #         loss = loss.view(tar.size())
        #         loss = loss.sum(1).float()
        #         loss = loss / tar_lens
        #         loss = loss.mean()
        #         return loss
        #     loss_tf = compute_loss(output_teacher_forcing, tar_flat)
        #     print_loss_tf.append(loss_tf.item())
        # SARI and BLEU are computed with model.eval(), without teacher forcing
        # for j in range(output_without_teacher_forcing.size()[0]):
        if True:  # write beam search here
            # try:
            if True:
                # example = batch_df.iloc[j]
                example = batch_df.iloc[0]
                # example_out = output_without_teacher_forcing[j, :, :]  # GREEDY
                # pred_action = torch.argmax(example_out, dim=1).view(-1).data.cpu().numpy()
                # edit_list_in_tokens = data.id2edits(pred_action, vocab)
                ### BEST BEAM
                edit_list_in_tokens = data.id2edits(best_seq_list[0][1:],
                                                    vocab)
                greedy_decoded_tokens = ' '.join(
                    edit2sent(example['comp_tokens'], edit_list_in_tokens))
                # truncate at the STOP token before scoring
                greedy_decoded_tokens = greedy_decoded_tokens.split(
                    'STOP')[0].split(' ')
                # tgt_tokens_translated = [vocab.i2w[i] for i in example['simp_ids']]
                sys_out.append(' '.join(greedy_decoded_tokens))
                # prt = True if random.random() < 0.01 else False
                # if prt:
                #     print('*' * 30)
                #     # print('tgt_in_tokens_translated', ' '.join(tgt_tokens_translated))
                #     print('ORG', ' '.join(example['comp_tokens']))
                #     print('GEN', ' '.join(greedy_decoded_tokens))
                #     print('TGT', ' '.join(example['simp_tokens']))
                #     print('edit_list_in_tokens', edit_list_in_tokens)
                #     print('gold labels', ' '.join(example['edit_labels']))
                bleu_list.append(
                    cal_bleu_score(greedy_decoded_tokens,
                                   example['simp_tokens']))
                # calculate SARI
                comp_string = ' '.join(example['comp_tokens'])
                simp_string = ' '.join(example['simp_tokens'])
                gen_string = ' '.join(greedy_decoded_tokens)
                sari_list.append(
                    SARIsent(comp_string, gen_string, [simp_string]))
    # print_loss_tf stays empty while the teacher-forcing block above is
    # commented out, so np.mean(print_loss_tf) evaluates to nan
    print('loss_with_teacher_forcing', np.mean(print_loss_tf))
    return np.mean(print_loss_tf), np.mean(bleu_list), np.mean(
        sari_list), sys_out
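# --- Sanity check for the SARI calls above (illustrative) ---
# SARIsent takes the original (complex) sentence, the system output, and a
# list of reference simplifications, all as whitespace-tokenized strings, and
# returns a score in [0, 1] (compute_sari above rescales to 0-100). The
# sentences here are made up; SARIsent is assumed to come from the standard
# SARI.py script, as elsewhere in this file.
comp_string = 'the cat perched atop the mat'
gen_string = 'the cat sat on the mat'
simp_string = 'the cat sat on the mat'
print(SARIsent(comp_string, gen_string,
               [simp_string]))  # high (near 1.0): output matches the reference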