def toy_output(self, batch, dec_outputs, beam_dict):
    """Print article, OOV words, reference and predictions for example 0 of ``batch``.

    Only the first element of the batch is inspected.

    Args:
        batch: batch object; element 0 of ``enc_inputs_oov``, ``dec_targets_oov``
            and ``encoder_oovs`` is decoded with ``ids2sentence``.
        dec_outputs: decoder outputs (log-probabilities, per the original inline
            comment). ``dec_outputs[0]`` is transposed (0, 1) and arg-maxed over
            dim 1 to obtain word ids -- presumably laid out as (vocab, steps);
            TODO confirm.
        beam_dict: optional dict; when truthy, ``beam_dict['beam_dec_outputs'][0][0]``
            (first example, first hypothesis) is decoded as well.
    """
    oov_words = batch.encoder_oovs[0]

    def _to_text(token_ids):
        # Map ids (including extended-vocab OOV ids) back to a sentence string.
        return ids2sentence(token_ids, self.vocab, oov_words)

    article_text = _to_text(batch.enc_inputs_oov[0].cpu().numpy())
    summary_text = _to_text(batch.dec_targets_oov[0].cpu().numpy())

    # Greedy prediction: most likely word id at every decoding step.
    greedy_ids = dec_outputs[0].transpose(0, 1).argmax(dim=1)
    greedy_text = _to_text(greedy_ids.cpu().detach().numpy())

    beam_text = None
    if beam_dict:
        best_hyp = beam_dict['beam_dec_outputs'][0][0]
        beam_text = _to_text(best_hyp.cpu().detach().numpy())

    report = (
        ("Article: ", article_text),
        ("OOV: ", "|".join(oov_words)),
        ("Summary: ", summary_text),
        ("Predicted Greedy Search: ", greedy_text),
    )
    for label, text in report:
        print(label + text)

    # Skipped when there is no beam output or it decodes to an empty string.
    if beam_text:
        print("Predicted Beam Search " + beam_text)
def _reward(self, batch, preds):
    """Return one ROUGE-L F-score reward per example for ``preds``.

    The reference summaries are built from ``batch.dec_targets_oov`` and made
    readable with the per-example ``num_w`` from ``self._num_w(batch)``. Scoring
    is delegated to ``pred2scores(..., batch_wise=False)``, which yields a list
    of ``rouge_l_f_score`` values (0 for examples whose scoring fails).

    Args:
        batch: batch object exposing ``dec_targets_oov`` and ``encoder_oovs``.
        preds: predicted id sequences, one per example.

    Returns:
        list of per-example scores.
    """
    num_w = self._num_w(batch)
    target_ids = batch.dec_targets_oov.cpu().numpy()

    references = []
    for position, (ids, oov) in enumerate(zip(target_ids, batch.encoder_oovs)):
        raw_text = ids2sentence(ids, self.vocab, oov)
        references.append(make_readable(raw_text, True, num_w[position]))

    return pred2scores(batch, references, preds, self.vocab, num_w, batch_wise=False)
def _map2win(self, inst, ws, ss):
    """Map every reference-summary sentence to a window via its best-matching article sentence.

    For each summary sentence (split on " . "), every article sentence is scored
    with ``vec_summary @ vec_article``, where both vectors come from
    ``self._aggregate``; the score is 0 when either vector is None. The
    top-scoring article sentence's token span [start, end) is converted with
    ``self._pos2win``. Finally a running maximum makes the resulting window
    sequence non-decreasing ("sequentialized").

    Args:
        inst: instance exposing ``decoder_target_pointer``,
            ``encoder_pointer_idx`` and ``encoder_oovs``.
        ws: window parameter forwarded to ``self.num_w`` and ``self._pos2win``
            (presumably window size -- confirm).
        ss: window parameter forwarded to ``self.num_w`` and ``self._pos2win``
            (presumably step size -- confirm).

    Returns:
        np.ndarray with one ``_pos2win`` result per summary sentence,
        non-decreasing.
    """
    summary = make_readable(
        ids2sentence(inst.decoder_target_pointer, self.vocab, inst.encoder_oovs), True)
    article = make_readable(
        ids2sentence(inst.encoder_pointer_idx, self.vocab, inst.encoder_oovs), True)
    num_w = self.num_w(inst.encoder_pointer_idx, ws, ss)

    art_sents = article.split(" . ")
    sum_sents = summary.split(" . ")

    # Aggregate each sentence once rather than once per (summary, article) pair.
    sum_vecs = [self._aggregate(sent) for sent in sum_sents]
    art_vecs = [self._aggregate(sent) for sent in art_sents]

    scores = [
        vec_s @ vec_a if vec_s is not None and vec_a is not None else 0
        for vec_s in sum_vecs
        for vec_a in art_vecs
    ]
    scores = np.array(scores).reshape(len(sum_sents), -1)

    # Token count of each article sentence in the " "-split article text.
    sent_lens = [len(sent.split(" ")) for sent in art_sents]

    ss2win = []
    for row in scores:
        best = int(np.argmax(row))
        # 0-based start: tokens of sentences 0..best-1 plus one "." separator
        # per preceding sentence. Summing per-sentence lengths gives 0 for the
        # first sentence. Joining an empty list would instead count a phantom
        # token, because "".split(" ") == [''].
        start = sum(sent_lens[:best]) + best
        end = start + sent_lens[best]
        ss2win.append(self._pos2win(start, end, ws, ss, num_w))

    # Enforce monotonically non-decreasing windows along the summary.
    return np.maximum.accumulate(ss2win)
def metric_scores(self, batch, dec_outputs, beam_dict, with_meteor=False):
    """Compute batch-level metrics for greedy and (optionally) beam-search predictions.

    Reference summaries are built from ``batch.dec_targets_oov`` and made
    readable with the per-example ``num_w`` from ``self._num_w(batch)``. That
    value is computed once and reused for both prediction sets.

    Args:
        batch: batch object exposing ``dec_targets_oov`` and ``encoder_oovs``.
        dec_outputs: decoder outputs or None. They are transposed (1, 2) and
            arg-maxed over the last dim to obtain word ids -- presumably
            (batch, vocab, steps); TODO confirm.
        beam_dict: optional dict; when truthy, ``beam_dict['beam_dec_outputs'][:, 0]``
            (the best hypothesis per example) is scored.
        with_meteor: forwarded to ``pred2scores``.

    Returns:
        tuple ``(metric_dict, beam_metric_dict)``. Each element is the
        ``pred2scores`` result, i.e. a dict with keys 'count',
        'rouge_batch_dict' and 'meteor', or None if that prediction set was not
        provided or scoring failed.
    """
    num_w = self._num_w(batch)
    system_summaries = [
        ids2sentence(tar, self.vocab, oov)
        for tar, oov in zip(batch.dec_targets_oov.cpu().numpy(), batch.encoder_oovs)
    ]
    system_summaries = [make_readable(s, True, num_w[i]) for i, s in enumerate(system_summaries)]

    metric_dict = None
    if dec_outputs is not None:
        # log probs -> word idx (greedy decoding).
        dec_preds = torch.argmax(dec_outputs.transpose(1, 2), dim=-1).cpu().numpy()
        metric_dict = pred2scores(batch, system_summaries, dec_preds, self.vocab,
                                  num_w, with_meteor=with_meteor)

    beam_metric_dict = None
    if beam_dict:
        dec_preds_beam = beam_dict['beam_dec_outputs'][:, 0].cpu().numpy()  # 0 - only best hyp
        beam_metric_dict = pred2scores(batch, system_summaries, dec_preds_beam, self.vocab,
                                       num_w, with_meteor=with_meteor)

    return metric_dict, beam_metric_dict
def pred2scores(batch, system_summaries, preds, vocab, num_w, batch_wise=True, with_meteor=False):
    """Score predicted id sequences against reference summaries.

    Predictions are decoded with ``ids2sentence`` (using each example's
    ``batch.encoder_oovs``) and made readable with the per-example ``num_w``.
    Scoring is best-effort: ordinary exceptions from scoring are converted into
    a neutral result. ``KeyboardInterrupt``/``SystemExit`` still propagate.

    Args:
        batch: batch object exposing ``encoder_oovs``.
        system_summaries: readable reference summaries, one per example.
        preds: predicted id sequences, one per example.
        vocab: vocabulary passed to ``ids2sentence``.
        num_w: per-example values forwarded to ``make_readable``.
        batch_wise: if True, score the whole batch in one ``setup_and_eval``
            call; otherwise score each example separately.
        with_meteor: forwarded to ``setup_and_eval`` (batch-wise mode only).

    Returns:
        batch_wise=True: dict with keys 'count', 'rouge_batch_dict' and
        'meteor', or None if scoring raised.
        batch_wise=False: list with one ``rouge_l_f_score`` per example (0 for
        an example whose scoring raised).
    """
    model_summaries = [ids2sentence(pred, vocab, oov) for pred, oov in zip(preds, batch.encoder_oovs)]
    model_summaries = [make_readable(s, True, num_w[i]) for i, s in enumerate(model_summaries)]

    if batch_wise:
        try:
            count, rouge_batch_dict, meteor = setup_and_eval(
                system_summaries, model_summaries, with_meteor=with_meteor)
        except Exception:
            return None
        return {'count': count, 'rouge_batch_dict': rouge_batch_dict, 'meteor': meteor}

    out = []
    for sys_sum, mod_sum in zip(system_summaries, model_summaries):
        try:
            _, rouge_dict, _ = setup_and_eval([sys_sum], [mod_sum])
            score = rouge_dict['rouge_l_f_score']
        except Exception:
            score = 0
        out.append(score)
    return out
# Decode one tokenized article with beam search, print the result and, in
# directory mode, write it to PREDICTION_DIR.
# NOTE(review): fragment of a larger routine. `tok` is defined outside this
# view -- presumably the path of a tokenized article file and a loop variable.
# The original nesting was lost when the source was collapsed; verify where the
# final `os.remove(MAP_PATH)` belongs (inside a loop it would fail on the
# second iteration).
article = read_text_file(tok)
# read_text_file's result is joined with " " -- presumably a list of lines.
instance = Instance(" ".join(article), None, STEPPER.vocab, CONFIG, None)
print("Article: ", " ".join(article))
oovs = [instance.encoder_oovs]
# Add a leading batch dimension of size 1.
idx = torch.from_numpy(instance.encoder_pointer_idx).unsqueeze(0)
# Encoder receives the OOV-masked ids; the unmasked `idx` is passed to the beam
# decoder below (presumably for copying OOV words -- confirm).
idx_no_oov = mask_oov(idx, STEPPER.vocab)
if CONFIG.encoder == 'Recurrent':
    enc_outputs, enc_state = STEPPER.encoder(idx_no_oov)
    dec_first_state = STEPPER.encoder.hidden_final(enc_state)
else: # Transformer
    enc_outputs = STEPPER.encoder(idx_no_oov)
    dec_first_state = STEPPER.encoder.hidden_final(enc_outputs)
STEPPER.bsdecoder.batch_size = 1
STEPPER.bsdecoder.dec_max_len = CONFIG.dec_max_len
beam_dec_outputs = STEPPER.bsdecoder(enc_outputs, dec_first_state, idx)
# [0][0]: first (only) example, first hypothesis.
beam_pred = ids2sentence(beam_dec_outputs[0][0].cpu().detach().numpy(), STEPPER.vocab, oovs[0])
print("Prediction", beam_pred)
if CFG['dir_mode']:
    # Output file name reuses the last 7 characters of `tok`.
    pred_path = os.path.join(PREDICTION_DIR, f"prediction_{tok[-7:]}")
    with open(pred_path, "w") as pf:
        pf.write(make_readable(beam_pred))
os.remove(MAP_PATH)
def vis_dict():
    """Tokenize the article at ``PATH``, decode it greedily and collect attention data.

    The pair ``PATH``/``TOK_PATH`` is appended to ``MAP_PATH``, and the Stanford
    PTBTokenizer is run over that file list. The tokenized article is then
    encoded and decoded by ``STEPPER``. ``MAP_PATH`` is always removed
    afterwards, even if a later step raises, so stale entries are not
    re-tokenized on the next call.

    Returns:
        dict with keys:
            "weights": attention weights cropped to
                [:len(summary tokens), :len(article tokens)] (after squeeze;
                presumably (summary_len, article_len) -- confirm).
            "summary": predicted summary split on " ".
            "article": tokenized article split on " ".
            "transitions": window transitions when ``CONFIG.windowing`` is on,
                else None.
    """
    with open(MAP_PATH, "a") as mf:
        mf.write(f"{PATH} \t {TOK_PATH}\n")
    try:
        command = [
            'java', '-cp', PARSER_JAR_PATH,
            'edu.stanford.nlp.process.PTBTokenizer',
            '-preserveLines', '-ioFileList', MAP_PATH,
        ]
        # Pass an argv list (no shell) so paths with spaces or shell
        # metacharacters are neither split nor interpreted.
        # NOTE(review): the exit status is not checked, as before.
        subprocess.call(command)

        article = " ".join(read_text_file(TOK_PATH))
        instance = Instance(article, None, STEPPER.vocab, CONFIG, None)
        # Add a leading batch dimension of size 1.
        idx = torch.from_numpy(instance.encoder_pointer_idx).unsqueeze(0)
        idx_no_oov = mask_oov(idx, STEPPER.vocab)
        if CONFIG.encoder == 'Recurrent':
            enc_outputs, enc_state = STEPPER.encoder(idx_no_oov)
            dec_first_state = STEPPER.encoder.hidden_final(enc_state)
        else:  # Transformer
            enc_outputs = STEPPER.encoder(idx_no_oov)
            dec_first_state = STEPPER.encoder.hidden_final(enc_outputs)

        STEPPER.decoder.dec_max_len = CONFIG.dec_max_len
        dec_outputs, att_weights = STEPPER.decoder(enc_outputs, dec_first_state, None, idx)
        # Greedy decoding: most likely word id at every step.
        pred = torch.argmax(dec_outputs.transpose(1, 2), dim=-1).squeeze().cpu().numpy()
        pred = ids2sentence(pred, STEPPER.vocab, instance.encoder_oovs)

        if CONFIG.windowing and CONFIG.w_type == 'dynamic':
            num_w = STEPPER.decoder.windower.scheduler.num_w(
                instance.encoder_pointer_idx, CONFIG.ws, CONFIG.ss)
            if STOP_DEC in pred:
                # Match the stop token literally. Unescaped, a token such as
                # "[STOP]" would act as a regex character class.
                eos_pos = tuple(re.finditer(re.escape(STOP_DEC), pred))
                # Keep text up to the num_w-th stop token (or the last one).
                last_eos = min(num_w, len(eos_pos))
                last_eos_pos = eos_pos[last_eos - 1].start()
                pred = pred[:last_eos_pos].strip()
                # Remaining stop tokens mark window transitions.
                pred = pred.replace(STOP_DEC, "-->")
        else:
            pred = make_readable(pred, False)

        transitions = None
        if CONFIG.windowing:
            if CONFIG.w_type == 'static':
                transitions = STEPPER.decoder.windower(instance.encoder_pointer_idx)
            else:  # dynamic
                transitions = np.where(np.array(pred.split(" ")) == "-->")[0] + 1

        summary_tokens = pred.split(" ")
        article_tokens = article.split(" ")
        w_d_ = {
            "weights": att_weights.squeeze().detach().cpu().numpy()[
                :len(summary_tokens), :len(article_tokens)],
            "summary": summary_tokens,
            "article": article_tokens,
            "transitions": transitions,
        }
        return w_d_
    finally:
        os.remove(MAP_PATH)