def load(self):
    if self.o.use_pair:
        # Each vocab line is a word followed by the POS tags observed with it.
        wordpos_vocab_list = [w.strip().split() for w in open(self.o.vocab_path)][:self.o.num_lex + 2]
        word_vocab_list = [wp[0] for wp in wordpos_vocab_list]
        word_vocab = Vocab.from_list(word_vocab_list, unk='<UNK>', pad='<PAD>')
        self.pos_dict = {
            word_vocab[wp[0]]: list(map(BLLIP_POS.__getitem__, wp[1:]))
            for wp in wordpos_vocab_list if len(wp) > 1
        }
    else:
        word_vocab_list = [w.strip() for w in open(self.o.vocab_path)][:self.o.num_lex + 2]
        word_vocab = Vocab.from_list(word_vocab_list, unk='<UNK>', pad='<PAD>')
    self.train_ds = ConllDataset(self.o.train_ds, pos_vocab=BLLIP_POS, word_vocab=word_vocab)
    self.dev_ds = ConllDataset(self.o.dev_ds, pos_vocab=BLLIP_POS, word_vocab=word_vocab)
    self.test_ds = ConllDataset(self.o.test_ds, pos_vocab=BLLIP_POS, word_vocab=word_vocab)
    if self.o.pretrained_ds:
        self.pretrained_ds = ConllDataset(self.o.pretrained_ds, pos_vocab=BLLIP_POS, word_vocab=word_vocab)
    else:
        self.pretrained_ds = None
    self.dev_ds.build_batchs(self.o.batch_size)
    self.test_ds.build_batchs(self.o.batch_size)
    if self.o.use_pair:
        self.o.num_lex = sum(len(p) for p in self.pos_dict.values())
    self.o.max_len = 10
    self.o.num_tag = len(BLLIP_POS) + self.o.num_lex
    if self.o.emb_path:
        self.word_emb = np.load(self.o.emb_path)[:self.o.num_lex + 2]
    else:
        self.word_emb = None
    self.out_pos_emb = np.load(self.o.out_pos_emb_path) if self.o.out_pos_emb_path else None
    self.pos_emb = np.load(self.o.pos_emb_path) if self.o.pos_emb_path else None
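# The vocab file format expected by the loader above is easiest to see by
# example. This is an illustration inferred from the parsing code, not taken
# from the source data:
#
#   use_pair mode - each line is a word followed by its observed POS tags:
#       the     DT
#       run     VB NN
#       flies   VBZ NNS
#
#   otherwise - each line is just the word. In both cases only the first
#   num_lex + 2 lines are kept (the +2 presumably covering <UNK> and <PAD>).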
def decoder():
    # Configure the GPU
    use_gpu = True
    config_gpu(use_gpu=use_gpu)

    # Load the vocabulary and the word-embedding matrix
    vocab = Vocab(config.path_vocab, config.vocab_size)
    wvtool = WVTool(ndim=config.emb_dim)
    embedding_matrix = wvtool.load_embedding_matrix(path_embedding_matrix=config.path_embedding_matrix)

    # Build the model
    logger.info('Building the Seq2Seq model ...')
    model = Seq2Seq(config.beam_size, embedding_matrix=embedding_matrix)

    # Checkpoint management
    ckpt = tf.train.Checkpoint(Seq2Seq=model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt, directory=config.dir_ckpt, max_to_keep=10)
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Decoder checkpoint restored from: {}'.format(ckpt_manager.latest_checkpoint))
    else:
        logger.info('No checkpoint available to restore')

    # Load the test data
    batcher = Batcher(config.path_seg_test, vocab, mode='decode',
                      batch_size=config.beam_size, single_pass=True)
    time.sleep(20)  # give the batcher's background threads time to fill the queue

    # Run decoding; inputs: batcher, model, vocab
    batch_decode(batcher, model=model, vocab=vocab)
def __init__(self, args, model_name=None):
    self.args = args
    vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
    self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file, args)
    self.train_batcher = Batcher(args.train_data_path, self.vocab, mode='train',
                                 batch_size=args.batch_size, single_pass=False, args=args)
    self.eval_batcher = Batcher(args.eval_data_path, self.vocab, mode='eval',
                                batch_size=args.batch_size, single_pass=True, args=args)
    time.sleep(30)  # give the batcher's background threads time to fill their queues

    if model_name is None:
        self.train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
    else:
        self.train_dir = os.path.join(config.log_root, model_name)

    if not os.path.exists(self.train_dir):
        os.mkdir(self.train_dir)
    self.model_dir = os.path.join(self.train_dir, 'model')
    if not os.path.exists(self.model_dir):
        os.mkdir(self.model_dir)
def __init__(self, par):
    file = par.cfg
    self.par = par
    # Defaults; overridden by the YAML config below.
    self.cuda = False
    self.reuse_words = False
    self.cell = None
    self.num_layers = None
    self.bidirectional = None
    self.hidden_size = None
    self.emb_src_size = None
    self.emb_tgt_size = None
    self.attention = 'dot'
    self.coverage = None
    self.pointer = None
    self.opt_method = None
    self.max_grad_norm = None
    self.n_iters_sofar = None
    with open(file, 'r') as stream:
        opts = yaml.safe_load(stream)  # yaml.load without an explicit Loader is deprecated
    for o, v in opts.items():
        if o == "cuda":
            self.cuda = bool(v) and torch.cuda.is_available()
        elif o == "cell":
            self.cell = v.lower()
        elif o == "reuse_words":
            self.reuse_words = bool(v)
        elif o == "num_layers":
            self.num_layers = int(v)
        elif o == "bidirectional":
            self.bidirectional = bool(v)
        elif o == "hidden_size":
            self.hidden_size = int(v)
        elif o == "emb_src_size":
            self.emb_src_size = int(v)
        elif o == "emb_tgt_size":
            self.emb_tgt_size = int(v)
        elif o == "attention":
            self.attention = v
        elif o == "coverage":
            self.coverage = bool(v)
        elif o == "pointer":
            self.pointer = bool(v)
        elif o == "opt_method":
            self.opt_method = v
        elif o == "max_grad_norm":
            self.max_grad_norm = float(v)
        else:
            sys.exit("error: unparsed {} config option.".format(o))
    if self.par.voc_src is None:
        sys.exit('error: missing -voc_src option')
    if self.coverage and self.attention != 'concat':
        sys.exit("error: option coverage must be used with attention: 'concat'")
    self.svoc = Vocab(self.par.voc_src)
    if self.reuse_words:
        # Source and target share one vocabulary and embedding size.
        self.tvoc = self.svoc
        self.emb_tgt_size = self.emb_src_size
    else:
        if self.par.voc_tgt is None:
            sys.exit('error: missing -voc_tgt option')
        self.tvoc = Vocab(self.par.voc_tgt)
    self.out()
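# For reference, a YAML config this parser accepts might look like the
# following sketch; the option names come from the parsing code above, the
# values are purely illustrative:
example_cfg = """
cuda: true
cell: lstm
reuse_words: false
num_layers: 2
bidirectional: true
hidden_size: 512
emb_src_size: 256
emb_tgt_size: 256
attention: concat
coverage: true      # checked above: coverage requires attention 'concat'
pointer: true
opt_method: adam
max_grad_norm: 5.0
"""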
def __init__(self, model):
    self._decode_dir = os.path.join(config.log_root, 'decode_%s' % ("model2"))
    self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
    self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
    for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
        if not os.path.exists(p):
            os.mkdir(p)
    self.vocab = Vocab(config.vocab_path, config.vocab_size)
    self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                           batch_size=1, single_pass=True)
    self.model = model
def __init__(self, args, model_file_path, save_path):
    model_name = os.path.basename(model_file_path)
    self.args = args
    self._decode_dir = os.path.join(config.log_root, save_path, 'decode_%s' % (model_name))
    self._structures_dir = os.path.join(self._decode_dir, 'structures')
    self._sent_single_heads_dir = os.path.join(self._decode_dir, 'sent_heads_preds')
    self._sent_single_heads_ref_dir = os.path.join(self._decode_dir, 'sent_heads_ref')
    self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
    self._contsel_ref_dir = os.path.join(self._decode_dir, 'content_sel_ref')
    self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
    self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
    self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
    self._rouge_pred_file = os.path.join(self._decode_dir, 'rouge_pred.json')
    self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
    self.sent_count_file = os.path.join(self._decode_dir, 'sent_used_counts.txt')
    for p in [self._decode_dir, self._structures_dir,
              self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
              self._contsel_ref_dir, self._contsel_dir,
              self._rouge_ref_dir, self._rouge_dec_dir]:
        if not os.path.exists(p):
            os.mkdir(p)
    vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
    self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file, args)
    self.batcher = Batcher(args.decode_data_path, self.vocab, mode='decode',
                           batch_size=args.beam_size, single_pass=True, args=args)
    self.batcher.setup_queues()
    time.sleep(30)
    self.model = Model(args, self.vocab).to(device)
    self.model.eval()
def main():
    vocab = Vocab(config.vocab_path, config.vocab_size)
    train_batcher = Batcher(config.train_data_path, vocab, mode='train',
                            batch_size=config.batch_size, single_pass=False)
    eval_batcher = Batcher(config.eval_data_path, vocab, mode='train',
                           batch_size=config.batch_size, single_pass=False)
    model = build_model(config)
    criterion = LabelSmoothing(config.vocab_size, train_batcher.pad_id, smoothing=0.1)
    if args.mode == 'train':
        train(config.max_iters, train_batcher, eval_batcher, model, criterion,
              config, args.save_path)
    elif args.mode == 'eval':
        eval(config, args.model)
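# The LabelSmoothing criterion used above is not shown in this snippet. As a
# rough sketch of what such a criterion typically does (in the style of the
# Annotated Transformer; the actual class here may differ), assuming the model
# outputs log-probabilities of shape (batch, vocab_size):
import torch
import torch.nn as nn

class LabelSmoothingSketch(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        # Build the smoothed target distribution: `confidence` mass on the gold
        # token, the rest spread uniformly over non-pad, non-gold tokens.
        true_dist = x.clone().detach()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist[target == self.padding_idx] = 0  # zero out rows whose target is pad
        return self.criterion(x, true_dist)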
def __init__(self, model_file_path):
    self.vocab = Vocab(config.vocab_path, config.vocab_size)
    self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
                           batch_size=config.batch_size, single_pass=True)
    time.sleep(15)
    model_name = os.path.basename(model_file_path)
    eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)
    self.summary_writer = SummaryWriter(eval_dir)
    self.model = Model(model_file_path, is_eval=True)
def load(self):
    pos_vocab_list = [w.strip() for w in open(self.o.vocab_path)]
    pos_vocab = Vocab.from_list(pos_vocab_list)
    self.train_ds = ConllDataset(self.o.train_ds, pos_vocab=pos_vocab)
    self.dev_ds = ConllDataset(self.o.dev_ds, pos_vocab=pos_vocab)
    self.test_ds = ConllDataset(self.o.test_ds, pos_vocab=pos_vocab)
    if self.o.pretrained_ds:
        self.pretrained_ds = ConllDataset(self.o.pretrained_ds, pos_vocab=pos_vocab)
    else:
        self.pretrained_ds = None
    self.dev_ds.build_batchs(self.o.batch_size)
    self.test_ds.build_batchs(self.o.batch_size)
    if self.o.emb_path:
        self.word_emb = np.load(self.o.emb_path)
    else:
        self.word_emb = None
# load examples
logging.info("Loading data...")
if dataset == "A":
    train = load_pickle("./data/SemEval/Task{0}/train.pkl".format(dataset))
    val = load_pickle("./data/SemEval/Task{0}/val.pkl".format(dataset))
    if test_mode:
        test = load_pickle("./data/SemEval/TaskA/test.pkl")
        train = merge_splits(train, val)
        val = test
logging.info("Number of training examples: {0}".format(len(train)))
logging.info("Number of validation examples: {0}".format(len(val)))
for ex in train[0][:3]:
    logging.info("Examples: {0}".format(ex))

# build vocab and data
logging.info("Building vocab...")
vocab = Vocab(train, min_freq, max_vocab_size)
vocab_size = len(vocab.word2id)
logging.info("Vocab size: {0}".format(vocab_size))

# use pretrained word embedding
logging.info("Loading word embedding from Magnitude...")
home = os.path.expanduser("~")
if embedding_size in [50, 100, 200]:
    vectors = Magnitude(
        os.path.join(home, "WordEmbedding/glove.twitter.27B.{0}d.magnitude".format(embedding_size)))
elif embedding_size in [300]:
    # vectors = Magnitude(os.path.join(home, "WordEmbedding/GoogleNews-vectors-negative{0}.magnitude".format(embedding_size)))
    vectors = Magnitude(
class BeamSearchDecoder:
    def __init__(self, model):
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % ("model2"))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=1, single_pass=True)
        self.model = model

    def beam_search(self, batch, conf):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros = \
            helper.prepare_src_batch(batch, conf)
        encoder_output = self.model.encode(enc_batch, enc_padding_mask)
        hyps_list = [
            Hypothesis(tokens=[self.vocab.word2id(data.START_DECODING)], log_probs=[0.0])
            for _ in range(1)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_len and len(results) < config.beam_size:
            hyp_tokens = [h.tokens for h in hyps_list]
            np_hyp_tokens = np.asarray(hyp_tokens)
            # Map in-article OOV ids back to UNK before feeding the decoder.
            np_hyp_tokens[np_hyp_tokens >= self.vocab.size()] = self.vocab.word2id(data.UNKNOWN_TOKEN)
            yt = torch.LongTensor(np_hyp_tokens).to(device)
            temp_enc_out = encoder_output.repeat(yt.size(0), 1, 1)
            out, _ = self.model.decode(temp_enc_out, yt, enc_padding_mask,
                                       helper.subsequent_mask(yt.size(-1)))
            extra_zeros_ip = None
            if extra_zeros is not None:
                extra_zeros_ip = extra_zeros[:, 0:steps + 1, :].repeat(yt.size(0), 1, 1)
            if conf.coverage:
                log_probs, _ = self.model.generator(out, temp_enc_out, enc_padding_mask,
                                                    enc_batch_extend_vocab, extra_zeros_ip)
            else:
                log_probs = self.model.generator(out, temp_enc_out, enc_padding_mask,
                                                 enc_batch_extend_vocab, extra_zeros_ip)
            log_probs = log_probs.squeeze(1)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)
            if len(topk_log_probs.size()) == 3:
                topk_log_probs = topk_log_probs[:, -1, :].squeeze(1)
                topk_ids = topk_ids[:, -1, :].squeeze(1)
            all_hyps = []
            num_orig_hyps = 1 if steps == 0 else len(hyps_list)
            for i in range(num_orig_hyps):
                h = hyps_list[i]
                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    hyp = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item())
                    all_hyps.append(hyp)
            hyps_list = []
            sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob, reverse=True)
            for h in sorted_hyps:
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    hyps_list.append(h)
                if len(hyps_list) == config.beam_size or len(results) == config.beam_size:
                    break
            steps += 1
        if len(results) == 0:
            results = hyps_list
        results_sorted = sorted(results, key=lambda h: h.avg_log_prob, reverse=True)
        return results_sorted[0]

    def decode(self, conf):
        self.model.eval()
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        i = 0
        while batch is not None:
            i += 1
            if i % 10 == 0:
                print(i)
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch, conf)
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass
            original_abstract_sents = batch.original_abstracts_sents[0]
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()
            batch = self.batcher.next_batch()

        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
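# A minimal driver for the decoder above (a sketch; `load_model` and `conf` are
# hypothetical stand-ins for however this codebase builds the model and its
# configuration object):
if __name__ == '__main__':
    model = load_model()  # hypothetical: returns the trained model
    decoder = BeamSearchDecoder(model)
    decoder.decode(conf)  # beam-search the test set, write ROUGE files, log scores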
logging.info("RESUME TRAINING") audio_conf = dict(sample_rate=args.sample_rate, window_size=args.window_size, window_stride=args.window_stride, window=args.window, noise_dir=args.noise_dir, noise_prob=args.noise_prob, noise_levels=(args.noise_min, args.noise_max)) logging.info(audio_conf) with open(args.labels_path, encoding="utf-8") as label_file: labels = json.load(label_file) vocab = Vocab() for label in labels: vocab.add_token(label) vocab.add_label(label) train_data_list = [] for i in range(len(args.train_manifest_list)): if args.feat == "spectrogram": train_data = SpectrogramDataset( vocab, args, audio_conf, manifest_filepath_list=args.train_manifest_list, normalize=True, augment=args.augment, input_type=args.input_type,
class BeamSearch(object):
    def __init__(self, args, model_file_path, save_path):
        model_name = os.path.basename(model_file_path)
        self.args = args
        self._decode_dir = os.path.join(config.log_root, save_path, 'decode_%s' % (model_name))
        self._structures_dir = os.path.join(self._decode_dir, 'structures')
        self._sent_single_heads_dir = os.path.join(self._decode_dir, 'sent_heads_preds')
        self._sent_single_heads_ref_dir = os.path.join(self._decode_dir, 'sent_heads_ref')
        self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
        self._contsel_ref_dir = os.path.join(self._decode_dir, 'content_sel_ref')
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
        self._rouge_pred_file = os.path.join(self._decode_dir, 'rouge_pred.json')
        self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
        self.sent_count_file = os.path.join(self._decode_dir, 'sent_used_counts.txt')
        for p in [self._decode_dir, self._structures_dir,
                  self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
                  self._contsel_ref_dir, self._contsel_dir,
                  self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file, args)
        self.batcher = Batcher(args.decode_data_path, self.vocab, mode='decode',
                               batch_size=args.beam_size, single_pass=True, args=args)
        self.batcher.setup_queues()
        time.sleep(30)
        self.model = Model(args, self.vocab).to(device)
        self.model.eval()

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def extract_structures(self, batch, sent_attention_matrix, doc_attention_matrix,
                           count, use_cuda, sent_scores):
        fileName = os.path.join(self._structures_dir, "%06d_struct.txt" % count)
        fp = open(fileName, "w")
        fp.write("Doc: " + str(count) + "\n")
        doc_attention_matrix = doc_attention_matrix[:, :]  # this change yet to be tested!
        l = batch.enc_doc_lens[0].item()
        doc_sent_no = 0
        # for i in range(l):
        #     printstr = ''
        #     sent = batch.enc_batch[0][i]
        #     token_count = 0
        #     for j in range(batch.enc_sent_lens[0][i].item()):
        #         token = sent[j].item()
        #         printstr += self.vocab.id2word(token) + " "
        #         token_count = token_count + 1
        #     fp.write(printstr + "\n")
        #
        #     scores = sent_attention_matrix[doc_sent_no][0:token_count, 0:token_count]
        #     shape2 = sent_attention_matrix[doc_sent_no][0:token_count, 0:token_count].size()
        #     row = torch.ones([1, shape2[1] + 1]).cuda()
        #     column = torch.zeros([shape2[0], 1]).cuda()
        #     new_scores = torch.cat([column, scores], dim=1)
        #     new_scores = torch.cat([row, new_scores], dim=0)
        #
        #     heads, tree_score = chu_liu_edmonds(new_scores.data.cpu().numpy().astype(np.float64))
        #     fp.write(str(heads) + " ")
        #     fp.write(str(tree_score) + "\n")
        #     doc_sent_no += 1

        shape2 = doc_attention_matrix[0:l, 0:l + 1].size()
        row = torch.zeros([1, shape2[1]]).cuda()
        scores = doc_attention_matrix[0:l, 0:l + 1]
        new_scores = torch.cat([row, scores], dim=0)
        val, root_edge = torch.max(new_scores[:, 0], dim=0)
        root_score = torch.zeros([shape2[0] + 1, 1]).cuda()
        root_score[root_edge] = 1
        new_scores[:, 0] = root_score.squeeze()
        heads, tree_score = chu_liu_edmonds(new_scores.data.cpu().numpy().astype(np.float64))
        height = find_height(heads)
        leaf_nodes = leaf_node_proportion(heads)
        fp.write("\n")
        sentences = str(batch.original_articles[0]).split("<split1>")
        for idx, sent in enumerate(sentences):
            fp.write(str(idx) + "\t" + str(sent) + "\n")
        fp.write(str(heads) + " ")
        fp.write(str(tree_score) + "\n")
        fp.write(str(height) + "\n")
        s = sent_scores[0].data.cpu().numpy()
        for val in s:
            fp.write(str(val))
        fp.close()
        structure_info = dict()
        structure_info['heads'] = heads
        structure_info['height'] = height
        structure_info['leaf_nodes'] = leaf_nodes
        return structure_info

    def decode(self):
        start = time.time()
        counter = 0
        sent_counter = []
        avg_max_seq_len_list = []
        copied_sequence_len = Counter()
        copied_sequence_per_sent = []
        article_copy_id_count_tot = Counter()
        sentence_copy_id_count = Counter()
        novel_counter = Counter()
        repeated_counter = Counter()
        summary_sent_count = Counter()
        summary_sent = []
        article_sent = []
        summary_len = []
        abstract_ref = []
        abstract_pred = []
        sentence_count = []
        tot_sentence_id_count = Counter()
        height_avg = []
        leaf_node_proportion_avg = []
        precision_tree_dist = []
        recall_tree_dist = []
        batch = self.batcher.next_batch()
        height_counter = Counter()
        leaf_nodes_counter = Counter()
        sent_count_fp = open(self.sent_count_file, 'w')
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        no_batches_processed = 0
        while batch is not None:
            # Run beam search to get best Hypothesis
            has_summary, best_summary, sample_predictions, sample_counts, structure_info, adj_mat = \
                self.get_decoded_outputs(batch, counter)
            if args.predict_contsel_tags:
                no_words = batch.enc_word_lens[0]
                prediction = sample_predictions['token_contsel_prediction'][0:no_words]
                ref = batch.contsel_tags[0]
                write_tags(prediction, ref, counter, self._contsel_dir, self._contsel_ref_dir)
                counts['token_consel_num_correct'] += sample_counts['token_consel_num_correct']
                counts['token_consel_num'] += sample_counts['token_consel_num']
            if args.predict_sent_single_head:
                no_sents = batch.enc_doc_lens[0]
                prediction = sample_predictions['sent_single_heads_prediction'][0:no_sents].tolist()
                ref = batch.original_parent_heads[0]
                write_tags(prediction, ref, counter, self._sent_single_heads_dir,
                           self._sent_single_heads_ref_dir)
                counts['sent_single_heads_num_correct'] += sample_counts['sent_single_heads_num_correct']
                counts['sent_single_heads_num'] += sample_counts['sent_single_heads_num']
            if args.predict_sent_all_head:
                counts['sent_all_heads_num_correct'] += sample_counts['sent_all_heads_num_correct']
                counts['sent_all_heads_num'] += sample_counts['sent_all_heads_num']
            if args.predict_sent_all_child:
                counts['sent_all_child_num_correct'] += sample_counts['sent_all_child_num_correct']
                counts['sent_all_child_num'] += sample_counts['sent_all_child_num']
            if not has_summary:
                batch = self.batcher.next_batch()
                continue
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if self.args.pointer_gen else None))
            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass
            original_abstract_sents = batch.original_abstracts_sents[0]
            summary_len.append(len(decoded_words))
            assert adj_mat is not None, "Explicit matrix is none."
            assert structure_info['heads'] is not None, "Heads is none."
            precision, recall = tree_distance(structure_info['heads'],
                                              adj_mat.cpu().data.numpy()[0, :, :])
            if precision is not None and recall is not None:
                precision_tree_dist.append(precision)
                recall_tree_dist.append(recall)
            height_counter[structure_info['height']] += 1
            height_avg.append(structure_info['height'])
            leaf_node_proportion_avg.append(structure_info['leaf_nodes'])
            leaf_nodes_counter[np.floor(structure_info['leaf_nodes'] * 10)] += 1
            abstract_ref.append(" ".join(original_abstract_sents))
            abstract_pred.append(" ".join(decoded_words))
            sent_res = get_sent_dist(" ".join(decoded_words),
                                     batch.original_articles[0].decode(),
                                     minimum_seq=self.args.minimum_seq)
            sent_counter.append((sent_res['seen_sent'], sent_res['article_sent']))
            summary_len.append(sent_res['summary_len'])
            summary_sent.append(sent_res['summary_sent'])
            summary_sent_count[sent_res['summary_sent']] += 1
            article_sent.append(sent_res['article_sent'])
            if sent_res['avg_copied_seq_len'] is not None:
                avg_max_seq_len_list.append(sent_res['avg_copied_seq_len'])
                copied_sequence_per_sent.append(
                    np.average(list(sent_res['counter_summary_sent_id'].values())))
            copied_sequence_len.update(sent_res['counter_copied_sequence_len'])
            sentence_copy_id_count.update(sent_res['counter_summary_sent_id'])
            article_copy_id_count_tot.update(sent_res['counter_article_sent_id'])
            novel_counter.update(sent_res['novel_ngram_counter'])
            repeated_counter.update(sent_res['repeated_ngram_counter'])
            sent_count_fp.write(str(counter) + "\t" + str(sent_res['article_sent']) + "\t" +
                                str(sent_res['seen_sent']) + "\n")
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            batch = self.batcher.next_batch()
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()
            if args.decode_for_subset:
                if counter == 1000:
                    break

        print("Decoder has finished reading dataset for single_pass.")
        fp = open(self.stat_res_file, 'w')
        percentages = [float(len(seen_sent)) / float(sent_count)
                       for seen_sent, sent_count in sent_counter]
        avg_percentage = sum(percentages) / float(len(percentages))
        nosents = [len(seen_sent) for seen_sent, sent_count in sent_counter]
        avg_nosents = sum(nosents) / float(len(nosents))
        res = dict()
        res['avg_percentage_seen_sent'] = avg_percentage
        res['avg_nosents'] = avg_nosents
        res['summary_len'] = summary_sent_count
        res['avg_summary_len'] = np.average(summary_len)
        res['summary_sent'] = np.average(summary_sent)
        res['article_sent'] = np.average(article_sent)
        res['avg_copied_seq_len'] = np.average(avg_max_seq_len_list)
        res['avg_sequences_per_sent'] = np.average(copied_sequence_per_sent)
        res['counter_copied_sequence_len'] = copied_sequence_len
        res['counter_summary_sent_id'] = sentence_copy_id_count
        res['counter_article_sent_id'] = article_copy_id_count_tot
        res['novel_ngram_counter'] = novel_counter
        res['repeated_ngram_counter'] = repeated_counter
        fp.write("Summary metrics\n")
        for key in res:
            fp.write('{}: {}\n'.format(key, res[key]))
        fp.write("Structures metrics\n")
        fp.write("Average depth of RST tree: " + str(sum(height_avg) / len(height_avg)) + "\n")
        fp.write("Average proportion of leaf nodes in RST tree: " +
                 str(sum(leaf_node_proportion_avg) / len(leaf_node_proportion_avg)) + "\n")
        fp.write("Precision of edges latent to explicit: " +
                 str(np.average(precision_tree_dist)) + "\n")
        fp.write("Recall of edges latent to explicit: " +
                 str(np.average(recall_tree_dist)) + "\n")
        fp.write("Tree height counter:\n")
        fp.write(str(height_counter) + "\n")
        fp.write("Tree leaf proportion counter:")
        fp.write(str(leaf_nodes_counter) + "\n")
        if args.predict_contsel_tags:
            fp.write("Avg token_contsel: " +
                     str(counts['token_consel_num_correct'] / float(counts['token_consel_num'])))
        if args.predict_sent_single_head:
            fp.write("Avg single sent heads: " +
                     str(counts['sent_single_heads_num_correct'] / float(counts['sent_single_heads_num'])))
        if args.predict_sent_all_head:
            fp.write("Avg all sent heads: " +
                     str(counts['sent_all_heads_num_correct'] / float(counts['sent_all_heads_num'])))
        if args.predict_sent_all_child:
            fp.write("Avg all sent child: " +
                     str(counts['sent_all_child_num_correct'] / float(counts['sent_all_child_num'])))
        fp.close()
        sent_count_fp.close()
        write_to_json_file(abstract_ref, self._rouge_ref_file)
        write_to_json_file(abstract_pred, self._rouge_pred_file)

    def get_decoded_outputs(self, batch, count):
        # batch should have only one example
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0, word_batch, word_padding_mask, \
            enc_word_lens, enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat, \
            weighted_adj_mat, norm_adj_mat, parent_heads, undir_weighted_adj_mat = \
            get_input_from_batch(batch, use_cuda, self.args)
        enc_adj_mat = adj_mat
        if args.use_weighted_annotations:
            if args.use_undirected_weighted_graphs:
                enc_adj_mat = undir_weighted_adj_mat
            else:
                enc_adj_mat = weighted_adj_mat
        encoder_output = self.model.encoder.forward_test(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask, word_batch, word_padding_mask, enc_word_lens,
            enc_tags_batch, enc_sent_token_mat, enc_adj_mat)
        encoder_outputs, enc_padding_mask, encoder_last_hidden, max_encoder_output, \
            enc_batch_extend_vocab, token_level_sentence_scores, sent_outputs, token_scores, \
            sent_scores, sent_matrix, sent_level_rep = \
            self.model.get_app_outputs(encoder_output, enc_padding_token_mask,
                                       enc_padding_sent_mask, enc_batch_extend_vocab,
                                       enc_sent_token_mat)
        mask = enc_padding_sent_mask[0].unsqueeze(0).repeat(enc_padding_sent_mask.size(1), 1) * \
            enc_padding_sent_mask[0].unsqueeze(1).transpose(1, 0)
        mask = torch.cat((enc_padding_sent_mask[0].unsqueeze(1), mask), dim=1)
        mat = encoder_output['sent_attention_matrix'][0][:, :] * mask
        structure_info = self.extract_structures(batch,
                                                 encoder_output['token_attention_matrix'],
                                                 mat, count, use_cuda,
                                                 encoder_output['sent_score'])
        counts = {}
        predictions = {}
        if args.predict_contsel_tags:
            pred = encoder_output['token_score'][0, :, :].view(-1, 2)
            token_contsel_gold = enc_tags_batch[0, :].view(-1)
            token_contsel_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            # Explicitly set masked tokens as different from value in gold
            token_contsel_prediction[token_contsel_gold == -1] = -2
            token_consel_num_correct = torch.sum(
                token_contsel_prediction.eq(token_contsel_gold)).item()
            token_consel_num = torch.sum(token_contsel_gold != -1).item()
            predictions['token_contsel_prediction'] = token_contsel_prediction
            counts['token_consel_num_correct'] = token_consel_num_correct
            counts['token_consel_num'] = token_consel_num
        if args.predict_sent_single_head:
            pred = encoder_output['sent_single_head_scores'][0, :, :]
            head_labels = parent_heads[0, :].view(-1)
            sent_single_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            # Explicitly set masked tokens as different from value in gold
            sent_single_heads_prediction[head_labels == -1] = -2
            sent_single_heads_num_correct = torch.sum(
                sent_single_heads_prediction.eq(head_labels)).item()
            sent_single_heads_num = torch.sum(head_labels != -1).item()
            predictions['sent_single_heads_prediction'] = sent_single_heads_prediction
            counts['sent_single_heads_num_correct'] = sent_single_heads_num_correct
            counts['sent_single_heads_num'] = sent_single_heads_num
        if args.predict_sent_all_head:
            pred = encoder_output['sent_all_head_scores'][0, :, :, :]
            target = adj_mat[0, :, :].permute(0, 1).view(-1)
            sent_all_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            # Explicitly set masked tokens as different from value in gold
            sent_all_heads_prediction[target == -1] = -2
            sent_all_heads_num_correct = torch.sum(
                sent_all_heads_prediction.eq(target)).item()
            sent_all_heads_num = torch.sum(target != -1).item()
            predictions['sent_all_heads_prediction'] = sent_all_heads_prediction
            counts['sent_all_heads_num_correct'] = sent_all_heads_num_correct
            counts['sent_all_heads_num'] = sent_all_heads_num
        if args.predict_sent_all_child:
            pred = encoder_output['sent_all_child_scores'][0, :, :, :]
            target = adj_mat[0, :, :].view(-1)
            sent_all_child_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            # Explicitly set masked tokens as different from value in gold
            sent_all_child_prediction[target == -1] = -2
            sent_all_child_num_correct = torch.sum(
                sent_all_child_prediction.eq(target)).item()
            sent_all_child_num = torch.sum(target != -1).item()
            predictions['sent_all_child_prediction'] = sent_all_child_prediction
            counts['sent_all_child_num_correct'] = sent_all_child_num_correct
            counts['sent_all_child_num'] = sent_all_child_num
        results = []
        steps = 0
        has_summary = False
        beams_sorted = [None]
        if args.predict_summaries:
            has_summary = True
            if args.fixed_scorer:
                scorer_output = self.model.module.pretrained_scorer.forward_test(
                    enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
                    enc_padding_sent_mask, word_batch, word_padding_mask,
                    enc_word_lens, enc_tags_batch)
                token_scores = scorer_output['token_score']
                sent_scores = scorer_output['sent_score'].unsqueeze(1).repeat(
                    1, enc_padding_token_mask.size(2), 1, 1).view(
                        enc_padding_token_mask.size(0),
                        enc_padding_token_mask.size(1) * enc_padding_token_mask.size(2))
            all_child, all_head = None, None
            if args.use_gold_annotations_for_decode:
                if args.use_weighted_annotations:
                    if args.use_undirected_weighted_graphs:
                        # Row-normalize head/child scores taken from the
                        # undirected weighted adjacency matrix.
                        permuted_all_head = undir_weighted_adj_mat[:, :, :].permute(0, 2, 1)
                        all_head = permuted_all_head.clone()
                        row_sums = torch.sum(permuted_all_head, dim=2, keepdim=True)
                        nonzero = row_sums.expand_as(permuted_all_head) != 0
                        all_head[nonzero] = permuted_all_head[nonzero] / \
                            row_sums.expand_as(permuted_all_head)[nonzero]
                        base_all_child = undir_weighted_adj_mat[:, :, :]
                        all_child = base_all_child.clone()
                        row_sums = torch.sum(base_all_child, dim=2, keepdim=True)
                        nonzero = row_sums.expand_as(base_all_child) != 0
                        all_child[nonzero] = base_all_child[nonzero] / \
                            row_sums.expand_as(base_all_child)[nonzero]
                    else:
                        permuted_all_head = weighted_adj_mat[:, :, :].permute(0, 2, 1)
                        all_head = permuted_all_head.clone()
                        row_sums = torch.sum(permuted_all_head, dim=2, keepdim=True)
                        nonzero = row_sums.expand_as(permuted_all_head) != 0
                        all_head[nonzero] = permuted_all_head[nonzero] / \
                            row_sums.expand_as(permuted_all_head)[nonzero]
                        base_all_child = weighted_adj_mat[:, :, :]
                        all_child = base_all_child.clone()
                        row_sums = torch.sum(base_all_child, dim=2, keepdim=True)
                        nonzero = row_sums.expand_as(base_all_child) != 0
                        all_child[nonzero] = base_all_child[nonzero] / \
                            row_sums.expand_as(base_all_child)[nonzero]
                else:
                    permuted_all_head = adj_mat[:, :, :].permute(0, 2, 1)
                    all_head = permuted_all_head.clone()
                    row_sums = torch.sum(permuted_all_head, dim=2, keepdim=True)
                    nonzero = row_sums.expand_as(permuted_all_head) != 0
                    all_head[nonzero] = permuted_all_head[nonzero] / \
                        row_sums.expand_as(permuted_all_head)[nonzero]
                    base_all_child = adj_mat[:, :, :]
                    all_child = base_all_child.clone()
                    row_sums = torch.sum(base_all_child, dim=2, keepdim=True)
                    nonzero = row_sums.expand_as(base_all_child) != 0
                    all_child[nonzero] = base_all_child[nonzero] / \
                        row_sums.expand_as(base_all_child)[nonzero]
                # all_head = adj_mat[:, :, :].permute(0, 2, 1) + 0.00005
                # row_sums = torch.sum(all_head, dim=2, keepdim=True)
                # all_head = all_head / row_sums
                # all_child = adj_mat[:, :, :] + 0.00005
                # row_sums = torch.sum(all_child, dim=2, keepdim=True)
                # all_child = all_child / row_sums
            s_t_0 = self.model.reduce_state(encoder_last_hidden)
            if config.use_maxpool_init_ctx:
                c_t_0 = max_encoder_output
            dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()
            # Decoder batch preparation: the beam initially holds beam_size
            # copies of the same start hypothesis.
            beams = [
                Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                     log_probs=[0.0],
                     state=(dec_h[0], dec_c[0]),
                     context=c_t_0[0],
                     coverage=(coverage_t_0[0] if self.args.is_coverage or
                               self.args.bu_coverage_penalty else None))
                for _ in range(args.beam_size)
            ]
            while steps < args.max_dec_steps and len(results) < args.beam_size:
                latest_tokens = [h.latest_token for h in beams]
                # Map in-article OOV ids back to UNK before feeding the decoder.
                latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN)
                                 for t in latest_tokens]
                y_t_1 = Variable(torch.LongTensor(latest_tokens))
                if use_cuda:
                    y_t_1 = y_t_1.cuda()
                all_state_h = []
                all_state_c = []
                all_context = []
                for h in beams:
                    state_h, state_c = h.state
                    all_state_h.append(state_h)
                    all_state_c.append(state_c)
                    all_context.append(h.context)
                s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                         torch.stack(all_state_c, 0).unsqueeze(0))
                c_t_1 = torch.stack(all_context, 0)
                coverage_t_1 = None
                if self.args.is_coverage or self.args.bu_coverage_penalty:
                    all_coverage = []
                    for h in beams:
                        all_coverage.append(h.coverage)
                    coverage_t_1 = torch.stack(all_coverage, 0)
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, word_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1, token_scores,
                    sent_scores, sent_outputs, enc_sent_token_mat, all_head,
                    all_child, sent_level_rep)
                if args.bu_coverage_penalty:
                    penalty = torch.max(coverage_t, coverage_t.clone().fill_(1.0)).sum(-1)
                    penalty -= coverage_t.size(-1)
                    final_dist -= args.beta * penalty.unsqueeze(1).expand_as(final_dist)
                if args.bu_length_penalty:
                    penalty = ((5 + steps + 1) / 6.0) ** args.alpha
                    final_dist = final_dist / penalty
                topk_log_probs, topk_ids = torch.topk(final_dist, args.beam_size * 2)
                dec_h, dec_c = s_t
                dec_h = dec_h.squeeze()
                dec_c = dec_c.squeeze()
                all_beams = []
                num_orig_beams = 1 if steps == 0 else len(beams)
                for i in range(num_orig_beams):
                    h = beams[i]
                    state_i = (dec_h[i], dec_c[i])
                    context_i = c_t[i]
                    coverage_i = (coverage_t[i] if self.args.is_coverage or
                                  self.args.bu_coverage_penalty else None)
                    for j in range(args.beam_size * 2):  # for each of the top 2*beam_size hyps:
                        new_beam = h.extend(token=topk_ids[i, j].item(),
                                            log_prob=topk_log_probs[i, j].item(),
                                            state=state_i,
                                            context=context_i,
                                            coverage=coverage_i)
                        all_beams.append(new_beam)
                beams = []
                for h in self.sort_beams(all_beams):
                    if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                        if steps >= config.min_dec_steps:
                            results.append(h)
                    else:
                        beams.append(h)
                    if len(beams) == args.beam_size or len(results) == args.beam_size:
                        break
                steps += 1
            if len(results) == 0:
                results = beams
            beams_sorted = self.sort_beams(results)
        return has_summary, beams_sorted[0], predictions, counts, structure_info, undir_weighted_adj_mat
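# Usage sketch for the BeamSearch class above (the checkpoint path is
# hypothetical; `args` is the module-level argument namespace the class
# already relies on):
if __name__ == '__main__':
    ckpt = os.path.join(config.log_root, 'train_exp', 'model', 'model_50000')  # hypothetical
    decoder = BeamSearch(args, ckpt, save_path='exp1')
    decoder.decode()  # writes structures, tag predictions, ROUGE refs/preds, and stats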