def prepro(tokenizer, d, max_len=150, max_sent_len=60):
    """ make sure data is not empty"""
    source_sents, extracts = d
    tokenized_sents = [
        tokenizer.tokenize(source_sent.lower())
        for source_sent in source_sents
    ]
    tokenized_sents = tokenized_sents[:max_sent_len]
    tokenized_sents = [['[CLS]'] + tokenized_sent[:max_len - 1]
                       for tokenized_sent in tokenized_sents]
    tokenized_sents = [
        tokenizer.convert_tokens_to_ids(tokenized_sent)
        for tokenized_sent in tokenized_sents
    ]
    word_num = [len(tokenized_sent) for tokenized_sent in tokenized_sents]
    abs_sents = tokenize(None, extracts)
    art_sents = tokenize(None, source_sents)
    return (art_sents, tokenized_sents, word_num), abs_sents
def prepro(tokenizer, d, max_len=512):
    """ make sure data is not empty"""
    source_sents, extracts = d
    tokenized_sents = [
        tokenizer.tokenize(source_sent.lower())
        for source_sent in source_sents
    ]
    tokenized_sents = [
        tokenized_sent + ['[SEP]'] for tokenized_sent in tokenized_sents
    ]
    tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
    word_num = [len(tokenized_sent) for tokenized_sent in tokenized_sents]
    truncated_word_num = []
    total_count = 0
    for num in word_num:
        if total_count + num < max_len:
            truncated_word_num.append(num)
        else:
            truncated_word_num.append(max_len - total_count)
            break
        total_count += num
    tokenized_sents = list(concat(tokenized_sents))[:max_len]
    tokenized_sents = tokenizer.convert_tokens_to_ids(tokenized_sents)
    abs_sents = tokenize(None, extracts)
    art_sents = tokenize(None, source_sents)
    return (art_sents, tokenized_sents, truncated_word_num), abs_sents
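
# Usage sketch (an added illustration, not code from this repo): with a pretrained BERT
# tokenizer, the prepro variant above returns one flat id sequence of at most max_len
# sub-word tokens plus the per-sentence token counts after truncation.
#
#   from pytorch_pretrained_bert import BertTokenizer  # or transformers.BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   d = (['the first sentence .', 'the second sentence .'], ['a reference summary .'])
#   (art_sents, token_ids, sent_lens), abs_sents = prepro(tokenizer, d)
#   # sum(sent_lens) == len(token_ids) and len(token_ids) <= 512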
def coll(batch):
    art_batch, abs_batch, query_batch = unzip(batch)
    query_batch_list = list(query_batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    queries = list(
        filter(bool,
               map(tokenize(None), [[query] for query in query_batch_list])))
    return art_sents, abs_sents, queries
def coll(batch):
    def is_good_data(d):
        """ make sure data is not empty"""
        source_sents, extracts = d
        return source_sents and extracts

    art_batch, abs_batch = unzip(batch)
    art_batch, abs_batch = list(
        zip(*list(filter(is_good_data, zip(art_batch, abs_batch)))))
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    return art_sents, abs_sents
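
# The coll/prepro helpers in this file rely on a curried `tokenize` utility imported from
# the data pipeline: tokenize(None) (or tokenize(max_len)) returns a callable that is then
# mapped over a batch. A minimal sketch of such a helper, assuming cytoolz's curry and
# simple lowercase whitespace splitting (the real implementation may differ):
from cytoolz import curry


@curry
def _tokenize_sketch(max_len, texts):
    """Hypothetical stand-in for the repo's tokenize(): lowercase each text,
    split it on whitespace, and optionally truncate to max_len tokens."""
    return [t.lower().split()[:max_len] for t in texts]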
def prepro(tokenizer, d, max_len=1024, stride=256):
    """ make sure data is not empty"""
    source_sents, extracts = d
    tokenized_sents = [
        tokenizer.tokenize(source_sent.lower())
        for source_sent in source_sents
    ]
    tokenized_sents = [['[CLS]'] + tokenized_sent
                       for tokenized_sent in tokenized_sents]
    tokenized_sents = [
        tokenizer.convert_tokens_to_ids(tokenized_sent)
        for tokenized_sent in tokenized_sents
    ]
    word_num = [len(tokenized_sent) for tokenized_sent in tokenized_sents]
    truncated_word_num = []
    total_count = 0
    for num in word_num:
        if total_count + num < max_len:
            truncated_word_num.append(num)
        else:
            truncated_word_num.append(max_len - total_count)
            break
        total_count += num
    tokenized_sents = list(concat(tokenized_sents))[:max_len]
    tokenized_sents_lists = [tokenized_sents[:BERT_MAX_LEN]]
    length = len(tokenized_sents) - BERT_MAX_LEN
    i = 1
    while length > 0:
        tokenized_sents_lists.append(
            tokenized_sents[(i * BERT_MAX_LEN -
                             stride):((i + 1) * BERT_MAX_LEN - stride)])
        i += 1
        length -= (BERT_MAX_LEN - stride)
    abs_sents = tokenize(None, extracts)
    art_sents = tokenize(None, source_sents)
    return (art_sents, tokenized_sents_lists, truncated_word_num), abs_sents
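
# Illustration of the window splitting above (an added sketch, assuming BERT_MAX_LEN == 512):
# the helper below mirrors the slicing loop and returns the (start, end) bounds of each
# chunk placed into tokenized_sents_lists.
def _bert_window_bounds(n_tokens, window=512, stride=256):
    """Mirror the slicing loop in the prepro variant above."""
    bounds = [(0, min(window, n_tokens))]
    length = n_tokens - window
    i = 1
    while length > 0:
        start = i * window - stride
        end = (i + 1) * window - stride
        bounds.append((start, min(end, n_tokens)))
        i += 1
        length -= (window - stride)
    return bounds
# e.g. _bert_window_bounds(1024) -> [(0, 512), (256, 768), (768, 1024)]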
def main(article_path, model_dir, batch_size, beam_size, diverse, max_len, cuda):
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        # the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)
    with open(article_path) as f:
        raw_article_batch = f.readlines()
    tokenized_article_batch = map(tokenize(None), raw_article_batch)
    ext_arts = []
    ext_inds = []
    for raw_art_sents in tokenized_article_batch:
        print(raw_art_sents)
        ext = extractor(raw_art_sents)[:-1]  # exclude EOE
        if not ext:
            # use top-5 if nothing is extracted
            # in some rare cases rnn-ext does not extract at all
            ext = list(range(5))[:len(raw_art_sents)]
        else:
            ext = [i.item() for i in ext]
        ext_inds += [(len(ext_arts), len(ext))]
        ext_arts += [raw_art_sents[i] for i in ext]
    if beam_size > 1:
        all_beams = abstractor(ext_arts, beam_size, diverse)
        dec_outs = rerank_mp(all_beams, ext_inds)
    else:
        dec_outs = abstractor(ext_arts)
    # assert i == batch_size*i_debug
    for j, n in ext_inds:
        decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
        print(decoded_sents)
def coll(batch):
    art_batch, abs_batch, i_batch = unzip(batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    return art_sents, abs_sents, list(i_batch)
def coll(batch):
    art_batch, topics, abs_batch = unzip(batch)
    # art_batch, topics, abs_batch, topic_label = unzip(batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    return art_sents, topics, abs_sents
def coll(batch): split_token = '<split>' pad = 0 art_batch, abs_batch, all_clusters = unzip(batch) art_sents = [] abs_sents = [] def is_good_data(d): """ make sure data is not empty""" source_sents, extracts = d return source_sents and extracts art_batch, abs_batch = list( zip(*list(filter(is_good_data, zip(art_batch, abs_batch))))) art_sents = list(filter(bool, map(tokenize(None), art_batch))) abs_sents = list(filter(bool, map(tokenize(None), abs_batch))) inputs = [] # merge cluster for art_sent, clusters in zip(art_sents, all_clusters): cluster_words = [] cluster_wpos = [] cluster_spos = [] for cluster in clusters: scluster_word = [] scluster_wpos = [] scluster_spos = [] for mention in cluster: if len(mention['text'].strip().split(' ')) == len( list( range(mention['position'][3] + 1, mention['position'][4] + 1))): scluster_word += mention['text'].lower().strip().split( ' ') scluster_wpos += list( range(mention['position'][3] + 1, mention['position'][4] + 1)) scluster_spos += [ mention['position'][0] + 1 for _ in range( len(mention['text'].strip().split(' '))) ] scluster_word.append(split_token) scluster_wpos.append(pad) scluster_spos.append(pad) else: sent_num = mention['position'][0] word_start = mention['position'][3] word_end = mention['position'][4] # if word_end > 99: # word_end = 99 if sent_num > len(art_sent) - 1: print('bad cluster') continue scluster_word += art_sent[sent_num][ word_start:word_end] scluster_wpos += list(range(word_start, word_end)) scluster_spos += [ mention['position'][0] + 1 for _ in range(word_start + 1, word_end + 1) ] scluster_word.append(split_token) scluster_wpos.append(pad) scluster_spos.append(pad) if scluster_word != []: scluster_word.pop() scluster_wpos.pop() scluster_spos.pop() cluster_words.append(scluster_word) cluster_wpos.append(scluster_wpos) cluster_spos.append(scluster_spos) if len(scluster_word) != len(scluster_wpos): print(scluster_word) print(scluster_wpos) print('cluster:', cluster) if len(scluster_word) != len(scluster_spos): print(scluster_word) print(scluster_spos) print('cluster:', cluster) assert len(scluster_word) == len(scluster_spos) and len( scluster_spos) == len(scluster_wpos) new_clusters = (cluster_words, cluster_wpos, cluster_spos) inputs.append((art_sent, new_clusters)) assert len(inputs) == len(abs_sents) return inputs, abs_sents
def decode(save_path, model_dir, split, batch_size, beam_size, diverse, max_len, cuda, bart=False, clip=-1, tri_block=False): start = time() # setup model with open(join(model_dir, 'meta.json')) as f: meta = json.loads(f.read()) if meta['net_args']['abstractor'] is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization assert beam_size == 1 abstractor = identity else: if beam_size == 1: abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda) else: abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda) extractor = RLExtractor(model_dir, cuda=cuda) # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles dataset = DecodeDataset(split) n_data = len(dataset) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll) # prepare save paths and logs os.makedirs(join(save_path, 'output')) dec_log = {} dec_log['abstractor'] = meta['net_args']['abstractor'] dec_log['extractor'] = meta['net_args']['extractor'] dec_log['rl'] = True dec_log['split'] = split dec_log['beam'] = beam_size dec_log['diverse'] = diverse with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) # Decoding i = 0 with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): # raw_article_batch tokenized_article_batch = map(tokenize(None), [r[0] for r in raw_article_batch]) tokenized_abs_batch = map(tokenize(None), [r[1] for r in raw_article_batch]) ext_arts = [] ext_inds = [] for raw_art_sents, raw_abs_sents in zip(tokenized_article_batch, tokenized_abs_batch): ext, raw_art_sents = extractor(raw_art_sents, raw_abs_sents=raw_abs_sents) # print(raw_art_sen/ts) ext = ext[:-1] # exclude EOE # print(ext) if tri_block: _pred = [] _ids = [] for j in ext: if (j >= len(raw_art_sents)): continue candidate = " ".join(raw_art_sents[j]).strip() if (not _block_tri(candidate, _pred)): _pred.append(candidate) _ids.append(j) else: continue if (len(_pred) == 3): break ext = _ids # print(_pred) if clip > 0 and len( ext) > clip: #ADDED FOR CLIPPING, CHANGE BACK # print("hi", clip) ext = ext[0:clip] if not ext: # use top-5 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(5))[:len(raw_art_sents)] else: ext = [i.item() for i in ext] # print(ext) ext_inds += [(len(ext_arts), len(ext))] ext_arts += [raw_art_sents[i] for i in ext] if bart: # print("hi") dec_outs = get_bart_summaries(ext_arts, tokenizer, bart_model, beam_size=beam_size) else: if beam_size > 1: all_beams = abstractor(ext_arts, beam_size, diverse) dec_outs = rerank_mp(all_beams, ext_inds) else: dec_outs = abstractor(ext_arts) # print(dec_outs, i, i_debug) assert i == batch_size * i_debug for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]] with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f: f.write(make_html_safe('\n'.join(decoded_sents))) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i / n_data * 100, timedelta(seconds=int(time() - start))), end='') print()
def decode(save_path, abs_dir, split, batch_size, max_len, cuda, min_len): start = time() # setup model if abs_dir is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization raise Exception('abs directory none!') else: #abstractor = Abstractor(abs_dir, max_len, cuda) abstractor = BeamAbstractor(abs_dir, max_len, cuda, min_len, reverse=args.reverse) bert = abstractor._bert if bert: tokenizer = abstractor._tokenizer if bert: import logging logging.basicConfig(level=logging.ERROR) # if args.docgraph or args.paragraph: # abstractor = BeamAbstractorGAT(abs_dir, max_len, cuda, min_len, reverse=args.reverse) # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles dataset = AbsDecodeDataset(split) n_data = len(dataset) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll) save_path = os.path.join(save_path, split) os.makedirs(save_path) # prepare save paths and logs dec_log = {} dec_log['abstractor'] = (None if abs_dir is None else json.load( open(join(abs_dir, 'meta.json')))) dec_log['rl'] = False dec_log['split'] = split dec_log['beam'] = 5 # greedy decoding only beam_size = 5 with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) os.makedirs(join(save_path, 'output')) # Decoding i = 0 length = 0 with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): if bert: tokenized_article_batch = map( tokenize_keepcase(args.max_input), raw_article_batch) else: tokenized_article_batch = map(tokenize(args.max_input), raw_article_batch) ext_arts = [] ext_inds = [] beam_inds = [] pre_abs = list(tokenized_article_batch) pre_abs = [article[0] for article in pre_abs] for j in range(len(pre_abs)): beam_inds += [(len(beam_inds), 1)] all_beams = abstractor(pre_abs, beam_size, diverse=1.0) dec_outs = rerank_mp(all_beams, beam_inds) for dec_out in dec_outs: if bert: text = ''.join(' '.join(dec_out).split(' ')) dec_out = bytearray([ tokenizer.byte_decoder[c] for c in text ]).decode('utf-8', errors=tokenizer.errors) dec_out = [dec_out] dec_out = sent_tokenize(' '.join(dec_out)) ext = [sent.split(' ') for sent in dec_out] ext_inds += [(len(ext_arts), len(ext))] ext_arts += ext dec_outs = ext_arts assert i == batch_size * i_debug for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]] with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f: f.write(make_html_safe('\n'.join(decoded_sents))) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i / n_data * 100, timedelta(seconds=int(time() - start))), end='') length += len(decoded_sents) print('average summary length:', length / i)
def prepro_subgraph_nobert(batch, max_sent_len=100, max_sent=60,
                           node_max_len=30, key='InSalientSent',
                           adj_type='edge_as_node'):
    source_sents, nodes, edges, subgraphs, paras, extracts = batch
    tokenized_sents = tokenize(max_sent_len, source_sents)[:max_sent]
    tokenized_sents_2 = tokenize(None, source_sents)[:max_sent]
    tokenized_article = list(concat(tokenized_sents_2))
    cleaned_extracts = list(
        filter(lambda e: e < len(tokenized_sents), extracts))
    max_len = len(tokenized_article)
    # tokenized_sents = [tokenized_sent + ['[SEP]'] for tokenized_sent in tokenized_sents]
    # tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
    sent_align_para = []
    last_idx = 0
    for sent in range(len(tokenized_sents)):
        flag = False
        for _idx, para in enumerate(paras):
            if sent in para:
                sent_align_para.append(_idx)
                last_idx = _idx
                flag = True
                break
        if not flag:
            sent_align_para.append(last_idx)
    assert len(sent_align_para) == len(tokenized_sents)
    sent_align_para.append(last_idx + 1)

    segment_feat_para = [
        sent_align_para[_sid] + 1
        if sent_align_para[_sid] < MAX_FREQ - 1 else MAX_FREQ - 1
        for _sid, sent in enumerate(tokenized_sents_2) for word in sent
    ]
    segment_feat_sent = [[
        sent_align_para[_sid] + 1
        if sent_align_para[_sid] < MAX_FREQ - 1 else MAX_FREQ - 1
        for word in sent
    ] for _sid, sent in enumerate(tokenized_sents)]

    sent_align_para = [[_] for _ in sent_align_para]

    word_num = [len(tokenized_sent) for tokenized_sent in tokenized_sents]
    truncated_word_num = word_num

    # find out of range and useless nodes
    other_nodes = set()
    oor_nodes = []  # out of range nodes will not included in the graph
    word_freq_feat, word_inpara_feat, sent_freq_feat, sent_inpara_freq_feat = \
        create_word_freq_in_para_feat(paras, tokenized_sents, tokenized_article)
    assert len(word_freq_feat) == len(tokenized_article) and \
        len(word_inpara_feat) == len(tokenized_article)
    for _id, content in nodes.items():
        words = [
            pos for mention in content['content']
            for pos in mention['word_pos'] if pos != -1
        ]
        words = [word for word in words if word < max_len]
        if len(words) != 0:
            other_nodes.add(_id)
        else:
            oor_nodes.append(_id)

    activated_nodes = set()
    for _id, content in edges.items():
        if content['content']['arg1'] not in oor_nodes and \
                content['content']['arg2'] not in oor_nodes:
            words = content['content']['word_pos']
            new_words = [word for word in words if word > -1 and word < max_len]
            if len(new_words) > 0:
                activated_nodes.add(content['content']['arg1'])
                activated_nodes.add(content['content']['arg2'])
    oor_nodes.extend(list(other_nodes - activated_nodes))

    # process nodes
    sorted_nodes = sorted(nodes.items(), key=lambda x: int(x[0].split('_')[1]))
    nodewords = []
    nodefreq = []
    sum_worthy = []
    id2node = {}
    ii = 0
    for _id, content in sorted_nodes:
        if _id not in oor_nodes:
            words = [
                pos for mention in content['content']
                for pos in mention['word_pos'] if pos != -1
            ]
            words = [word for word in words if word < max_len]
            words = words[:node_max_len]
            sum_worthy.append(content[key])
            if len(words) != 0:
                nodewords.append(words)
                nodefreq.append(len(content['content']))
                id2node[_id] = ii
                ii += 1
            else:
                oor_nodes.append(_id)
    if len(nodewords) == 0:
        # print('warning! no nodes in this sample')
        nodewords = [[0], [2]]
        nodefreq.extend([1, 1])
        sum_worthy.extend([0, 0])
    nodelength = [len(words) for words in nodewords]

    # process edges
    acticated_nodes = set()
    triples = []
    edge_freq = []
    relations = []
    sum_worthy_edges = []
    id2edge = {}
    sorted_edges = sorted(edges.items(), key=lambda x: int(x[0].split('_')[1]))
    ii = 0
    for _id, content in sorted_edges:
        if content['content']['arg1'] not in oor_nodes and \
                content['content']['arg2'] not in oor_nodes:
            words = content['content']['word_pos']
            new_words = [word for word in words if word > -1 and word < max_len]
            new_words = new_words[:node_max_len]
            if len(new_words) > 0:
                node1 = id2node[content['content']['arg1']]
                node2 = id2node[content['content']['arg2']]
                sum_worthy_edges.append(content[key])
                if adj_type == 'edge_up':
                    nodewords[node1].extend(new_words)
                elif adj_type == 'edge_down':
                    nodewords[node2].extend(new_words)
                edge = int(_id.split('_')[1])
                edge_freq.append(1)
                triples.append([node1, ii, node2])
                acticated_nodes.add(content['content']['arg1'])
                acticated_nodes.add(content['content']['arg2'])
                id2edge[_id] = ii
                ii += 1
                relations.append(new_words)
    if len(relations) == 0:
        # print('warning! no edges in this sample')
        relations = [[1]]
        edge_freq = [1]
        triples = [[0, 0, 1]]
        sum_worthy_edges.extend([0])

    node_lists = []
    edge_lists = []
    triples = []
    if max_sent is None:
        max_sent = 9999
    for _sgid, subgraph in enumerate(subgraphs):
        try:
            paraid = paras[_sgid][0]
        except:
            paraid = 0
        if type(paraid) != type(max_sent):
            paraid = 0
        if paraid > max_sent - 1:
            continue
        if subgraph == []:
            node_lists.append([])
            triples.append([])
            edge_lists.append([])
        else:
            node_list = set()
            triple = []
            edge_list = []
            eidx = []
            for _triple in subgraph:
                if _triple[0] not in oor_nodes and _triple[2] not in oor_nodes \
                        and id2edge.__contains__(_triple[1]):
                    node_list.add(id2node[_triple[0]])
                    node_list.add(id2node[_triple[2]])
                    eidx.append(_triple[1])
            node_list = list(sorted(node_list))
            for _triple in subgraph:
                if _triple[0] not in oor_nodes and _triple[2] not in oor_nodes \
                        and id2edge.__contains__(_triple[1]):
                    idx1 = node_list.index(id2node[_triple[0]])
                    idx2 = node_list.index(id2node[_triple[2]])
                    _idxe = id2edge[_triple[1]]
                    idxe_in_subgraph = eidx.index(_triple[1])
                    edge_list.append(_idxe)
                    triple.append([idx1, idxe_in_subgraph, idx2])
            triples.append(triple)
            node_lists.append(node_list)
            edge_lists.append(edge_list)
    if len(node_lists) == 0:
        node_lists.append([])
        triples.append([])
        edge_lists.append([])

    rlength = [len(words) for words in relations]
    nodefreq = [
        freq if freq < MAX_FREQ - 1 else MAX_FREQ - 1 for freq in nodefreq
    ]
    if adj_type == 'edge_as_node':
        node_num = len(nodewords)
        nodewords = nodewords + relations
        nodefreq = nodefreq + edge_freq
        nodelength = nodelength + rlength
        sum_worthy = sum_worthy + sum_worthy_edges
        for i in range(len(triples)):
            node_lists[i] = node_lists[i] + [
                edge + node_num for edge in edge_lists[i]
            ]

    gold_dec_selection_label = [0 for i in range(len(node_lists))]
    for sent in cleaned_extracts:
        for i, para in enumerate(paras):
            if sent in para:
                gold_dec_selection_label[i] = 1

    return tokenized_article, truncated_word_num, \
        (nodewords, sum_worthy, gold_dec_selection_label), \
        (relations, triples, node_lists, sent_align_para, segment_feat_sent,
         segment_feat_para, nodefreq, word_freq_feat, word_inpara_feat,
         sent_freq_feat, sent_inpara_freq_feat)
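
# Added note on the return value of prepro_subgraph_nobert above:
#   tokenized_article   - tokens of the first max_sent sentences, concatenated
#   truncated_word_num  - per-sentence token counts after max_sent_len truncation
#   (nodewords, sum_worthy, gold_dec_selection_label)
#                       - node word positions, node salience labels, and a 0/1
#                         label per paragraph subgraph marking (roughly) whether
#                         it contains an extracted sentence
#   (relations, triples, node_lists, sent_align_para, segment_feat_sent,
#    segment_feat_para, nodefreq, word_freq_feat, word_inpara_feat,
#    sent_freq_feat, sent_inpara_freq_feat)
#                       - edge/graph structure plus frequency and segment features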
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda, trans=False): start = time() # setup model if abs_dir is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization abstractor = identity else: abstractor = Abstractor(abs_dir, max_len, cuda) if ext_dir is None: # NOTE: if no abstractor is provided then # it would be the lead-N extractor extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM] else: extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda) # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles dataset = DecodeDataset(split) n_data = len(dataset) loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll ) # prepare save paths and logs for i in range(MAX_ABS_NUM): os.makedirs(join(save_path, 'output_{}'.format(i))) # os.makedirs(join(save_path, 'output')) dec_log = {} dec_log['abstractor'] = (None if abs_dir is None else json.load(open(join(abs_dir, 'meta.json')))) dec_log['extractor'] = (None if ext_dir is None else json.load(open(join(ext_dir, 'meta.json')))) dec_log['rl'] = False dec_log['split'] = split dec_log['beam'] = 1 # greedy decoding only with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) # Decoding i = 0 with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): if trans: tokenized_article_batch = raw_article_batch # else: tokenized_article_batch = map(tokenize(None), raw_article_batch) ext_arts = [] ext_inds = [] for raw_art_sents in tokenized_article_batch: if trans: ext, batch = extractor(raw_art_sents) art_sents = batch.src_str[0] # print(ext, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels]) for k, idx in enumerate([ext]): _pred = [] _ids = [] if (len(batch.src_str[k]) == 0): continue for j in idx[:min(len(ext), len(batch.src_str[k]))]: if (j >= len(batch.src_str[k])): continue candidate = batch.src_str[k][j].strip() if (not _block_tri(candidate, _pred)): _pred.append(candidate) _ids.append(j) else: continue if (len(_pred) == 3): break # print(ext, _ids, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels], list(map(lambda i: art_sents[i], ext))) ext = _ids ext_inds += [(len(ext_arts), len(ext))] ext_arts += list(map(lambda i: art_sents[i], ext)) else: ext = extractor(raw_art_sents) ext_inds += [(len(ext_arts), len(ext))] ext_arts += list(map(lambda i: raw_art_sents[i], ext)) dec_outs = abstractor(ext_arts) # print(dec_outs) assert i == batch_size*i_debug for j, n in ext_inds: if trans: decoded_sents = dec_outs[j:j+n] else: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]] for k, dec_str in enumerate(decoded_sents): with open(join(save_path, 'output_{}/{}.dec'.format(k, i)), 'w') as f: f.write(make_html_safe(dec_str)) #f.write(make_html_safe('\n'.join(decoded_sents))) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i/n_data*100, timedelta(seconds=int(time()-start)) ), end='') # if i_debug == 1: # break print()
def decode(save_path, model_dir, split, batch_size, beam_size, diverse, max_len, cuda): start = time() # setup model with open(join(model_dir, 'meta.json')) as f: meta = json.loads(f.read()) if meta['net_args']['abstractor'] is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization assert beam_size == 1 abstractor = identity else: if beam_size == 1: abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda) else: abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda) extractor = RLExtractor(model_dir, cuda=cuda) # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles dataset = DecodeDataset(split) n_data = len(dataset) loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll ) # prepare save paths and logs os.makedirs(join(save_path, 'output')) dec_log = {} dec_log['abstractor'] = meta['net_args']['abstractor'] dec_log['extractor'] = meta['net_args']['extractor'] dec_log['rl'] = True dec_log['split'] = split dec_log['beam'] = beam_size dec_log['diverse'] = diverse with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) # Decoding i = 0 with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): tokenized_article_batch = map(tokenize(None), raw_article_batch) ext_arts = [] ext_inds = [] for raw_art_sents in tokenized_article_batch: ext = extractor(raw_art_sents)[:-1] # exclude EOE if not ext: # use top-5 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(5))[:len(raw_art_sents)] else: ext = [i.item() for i in ext] ext_inds += [(len(ext_arts), len(ext))] ext_arts += [raw_art_sents[i] for i in ext] if beam_size > 1: all_beams = abstractor(ext_arts, beam_size, diverse) dec_outs = rerank_mp(all_beams, ext_inds) else: dec_outs = abstractor(ext_arts) assert i == batch_size*i_debug for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]] with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f: f.write(make_html_safe('\n'.join(decoded_sents))) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i/n_data*100, timedelta(seconds=int(time()-start)) ), end='') print()
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        # the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda, min_len)
    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # check if use bert
    try:
        _bert = extractor._net._bert
    except:
        _bert = False
        print('no bert arg:')
    if _bert:
        tokenizer = BertTokenizer.from_pretrained(
            'bert-large-uncased-whole-word-masking')
        print('bert tokenizer loaded')

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)
    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                ext_arts = []
                ext_inds = []
                if _bert:
                    for raw_art_sents, raw_art in zip(tokenized_article_batch,
                                                      raw_article_batch):
                        tokenized_sents = [
                            tokenizer.tokenize(source_sent.lower())
                            for source_sent in raw_art
                        ]
                        tokenized_sents = [
                            tokenized_sent + ['[SEP]']
                            for tokenized_sent in tokenized_sents
                        ]
                        tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
                        word_num = [
                            len(tokenized_sent)
                            for tokenized_sent in tokenized_sents
                        ]
                        truncated_word_num = []
                        total_count = 0
                        for num in word_num:
                            if total_count + num < MAX_LEN_BERT:
                                truncated_word_num.append(num)
                            else:
                                truncated_word_num.append(MAX_LEN_BERT -
                                                          total_count)
                                break
                            total_count += num
                        tokenized_sents = list(
                            concat(tokenized_sents))[:MAX_LEN_BERT]
                        tokenized_sents = tokenizer.convert_tokens_to_ids(
                            tokenized_sents)
                        art_sents = tokenize(None, raw_art)
                        _input = (art_sents, tokenized_sents,
                                  truncated_word_num)
                        ext = extractor(_input)[:]  # exclude EOE
                        if not ext:
                            # use top-3 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(3))[:len(raw_art_sents)]
                        else:
                            ext = [i for i in ext]
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                else:
                    for raw_art_sents in tokenized_article_batch:
                        ext = extractor(raw_art_sents)[:]  # exclude EOE
                        if not ext:
                            # use top-5 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(5))[:len(raw_art_sents)]
                        else:
                            ext = [i for i in ext]
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
def decode_graph(save_path, model_dir, split, batch_size, beam_size, diverse, max_len, cuda, sc, min_len, docgraph, paragraph): start = time() # setup model with open(join(model_dir, 'meta.json')) as f: meta = json.loads(f.read()) if meta['net_args']['abstractor'] is None: #if not meta['net_args'].__contains__('abstractor'): # NOTE: if no abstractor is provided then # the whole model would be extractive summarization assert beam_size == 1 abstractor = identity else: if beam_size == 1: abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda) else: abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda, min_len=min_len) print('docgraph:', docgraph) extractor = SCExtractor(model_dir, cuda=cuda, docgraph=docgraph, paragraph=paragraph) adj_type = extractor._net._adj_type bert = extractor._net._bert if bert: tokenizer = extractor._net._bert try: with open( '/data/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl', 'rb') as f: align = pickle.load(f) except FileNotFoundError: with open( '/data2/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl', 'rb') as f: align = pickle.load(f) try: with open( '/data/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl', 'rb') as f: align2 = pickle.load(f) except FileNotFoundError: with open( '/data2/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl', 'rb') as f: align2 = pickle.load(f) align.update(align2) # setup loader def coll(batch): batch = list(filter(bool, batch)) return batch dataset = DecodeDatasetGAT(split, args.key) n_data = len(dataset) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll) # prepare save paths and logs os.makedirs(join(save_path, 'output')) dec_log = {} dec_log['abstractor'] = meta['net_args']['abstractor'] dec_log['extractor'] = meta['net_args']['extractor'] dec_log['rl'] = True dec_log['split'] = split dec_log['beam'] = beam_size dec_log['diverse'] = diverse with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) # Decoding i = 0 length = 0 sent_selected = 0 with torch.no_grad(): for i_debug, raw_input_batch in enumerate(loader): raw_article_batch, nodes, edges, paras, subgraphs = zip( *raw_input_batch) if bert: art_sents = [[ tokenizer.tokenize(source_sent) for source_sent in source_sents ] for source_sents in raw_article_batch] for _i in range(len(art_sents)): art_sents[_i][0] = [tokenizer.bos_token] + art_sents[_i][0] art_sents[_i][-1] = art_sents[_i][-1] + [ tokenizer.eos_token ] truncated_word_nums = [] word_nums = [[len(sent) for sent in art_sent] for art_sent in art_sents] for word_num in word_nums: truncated_word_num = [] total_count = 0 for num in word_num: if total_count + num < args.max_dec_word: truncated_word_num.append(num) else: truncated_word_num.append(args.max_dec_word - total_count) break total_count += num truncated_word_nums.append(truncated_word_num) sources = [ list(concat(art_sent))[:args.max_dec_word] for art_sent in art_sents ] else: tokenized_article_batch = map(tokenize(None), raw_article_batch) #processed_clusters = map(preproc(list(tokenized_article_batch), clusters)) #processed_clusters = list(zip(*processed_clusters)) ext_arts = [] ext_inds = [] pre_abs = [] beam_inds = [] if bert: for raw_art_sents, source, art_sent, word_num, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip( raw_article_batch, sources, art_sents, truncated_word_nums, nodes, edges, paras, subgraphs): processed_nodes = 
prepro_rl_graph_bert( align, raw_art_sents, source, art_sent, args.max_dec_word, raw_nodes, raw_edges, raw_paras, raw_subgraphs, adj_type, docgraph) _input = (raw_art_sents, source) + processed_nodes + (word_num, ) ext = extractor(_input)[:] sent_selected += len(ext) if not ext: # use top-3 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(3))[:len(raw_art_sents)] else: ext = [i for i in ext] ext_art = list(map(lambda i: raw_art_sents[i], ext)) pre_abs.append([word for sent in ext_art for word in sent]) beam_inds += [(len(beam_inds), 1)] else: for raw_art_sents, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip( tokenized_article_batch, nodes, edges, paras, subgraphs): processed_nodes = prepro_rl_graph(raw_art_sents, raw_nodes, raw_edges, raw_paras, raw_subgraphs, adj_type, docgraph) _input = (raw_art_sents, ) + processed_nodes ext = extractor(_input)[:] # exclude EOE sent_selected += len(ext) if not ext: # use top-3 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(3))[:len(raw_art_sents)] else: ext = [i for i in ext] ext_art = list(map(lambda i: raw_art_sents[i], ext)) pre_abs.append([word for sent in ext_art for word in sent]) beam_inds += [(len(beam_inds), 1)] if beam_size > 1: # all_beams = abstractor(ext_arts, beam_size, diverse) # dec_outs = rerank_mp(all_beams, ext_inds) all_beams = abstractor(pre_abs, beam_size, diverse=1.0) dec_outs = rerank_mp(all_beams, beam_inds) else: dec_outs = abstractor(pre_abs) for dec_out in dec_outs: dec_out = sent_tokenize(' '.join(dec_out)) ext = [sent.split(' ') for sent in dec_out] ext_inds += [(len(ext_arts), len(ext))] ext_arts += ext dec_outs = ext_arts assert i == batch_size * i_debug for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]] with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f: f.write(make_html_safe('\n'.join(decoded_sents))) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i / n_data * 100, timedelta(seconds=int(time() - start))), end='') length += len(decoded_sents) print('average summary length:', length / i) print('average sentence selected:', sent_selected)
def decode(args, predict=False): # save_path = args.path batch_size = args.batch beam_size = args.beam diverse = args.div start = time() extractor = args.extractor abstractor = args.abstractor # setup model text = '' # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles if not predict: dataset = DecodeDataset(args) n_data = len(dataset) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll) else: n_data = 1 loader = clean_and_split(args.text) loader = [[[' '.join(mecab_tokenizer(line)) for line in loader]]] text = '\n'.join(loader[0][0]) i = 0 #print(text) with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): tokenized_article_batch = map(tokenize(None), raw_article_batch) ext_arts = [] ext_inds = [] for raw_art_sents in tokenized_article_batch: ext = extractor(raw_art_sents)[:-1] # exclude EOE if not ext: # use top-5 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(5))[:len(raw_art_sents)] else: ext = [i.item() for i in ext] ext_inds += [(len(ext_arts), len(ext))] ext_arts += [raw_art_sents[i] for i in ext] if beam_size > 1: #print(ext_arts) all_beams = abstractor(ext_arts, beam_size, diverse) dec_outs = rerank_mp(all_beams, ext_inds) else: dec_outs = abstractor(ext_arts) assert i == batch_size * i_debug source_text = [''.join(sent) for sent in ext_arts] for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]] decoded_sents = decoded_sents[:20] # with open(join(save_path, 'output/{}.dec'.format(i)), # 'w') as f: # f.write(make_html_safe('\n'.join(decoded_sents))) result = make_html_safe('\n\n'.join(decoded_sents)) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i / n_data * 100, timedelta(seconds=int(time() - start))), end='') print() return text, result, source_text
def coll(tokenizer, align, max_src_len, adj_type, batch): art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras = unzip( batch) def is_good_data(d): """ make sure data is not empty""" source_sents, extracts, nodes, edges, subgraphs, paras = d return source_sents and extracts and nodes and subgraphs and paras art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras = list( zip(*list( filter( is_good_data, zip(art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras))))) old_sources = art_batch art_sents = [[ tokenizer.tokenize(source_sent) for source_sent in source_sents ] for source_sents in art_batch] for _i in range(len(art_sents)): art_sents[_i][0] = [tokenizer.bos_token] + art_sents[_i][0] art_sents[_i][-1] = art_sents[_i][-1] + [tokenizer.eos_token] truncated_word_nums = [] word_nums = [[len(sent) for sent in art_sent] for art_sent in art_sents] for word_num in word_nums: truncated_word_num = [] total_count = 0 for num in word_num: if total_count + num < max_src_len: truncated_word_num.append(num) else: truncated_word_num.append(max_src_len - total_count) break total_count += num truncated_word_nums.append(truncated_word_num) sources = [ list(concat(art_sent))[:max_src_len] for art_sent in art_sents ] raw_art_sents = list(filter(bool, map(tokenize(None), art_batch))) abs_sents = list(filter(bool, map(tokenize(None), abs_batch))) max_sents = list( map(count_max_sent(max_source_num=max_src_len), art_sents)) inputs = [] # merge cluster nodewords, nodelength, nodefreq, sum_worthy, triples, relations = \ list(zip(*[process_nodes_bert(align, node, edge, len(source) - 1, max_sent, key='InSalientSent', adj_type=adj_type, source_sent=sent, paras=para, subgraphs=subgraph, docgraph=docgraph, source=source) for node, edge, sent, para, subgraph, source, max_sent in zip(all_nodes, all_edges, old_sources, all_paras, all_subgraphs, sources, max_sents)])) # for art_sent, nodes, edges, subgraphs, paras in zip(art_sents, all_nodes, all_edges, all_subgraphs, all_paras): # max_len = len(list(concat(art_sent))) # # nodewords, nodelength, nodefreq, sum_worthy, triples, relations, sent_node_aligns = process_nodes(nodes, edges, max_len, max_sent_num=len(list(art_sent)), key=key, adj_type=adj_type) # if paragraph: # nodewords, node_lists, nodefreq, sum_worthy, triples, relations = process_subgraphs( # nodes, edges, subgraphs, paras, max_len, max_sent=len(list(art_sent)), # key=key, adj_type=adj_type # ) # sent_align_para = [] # last_idx = 0 # for sent in range(len(art_sent)): # flag = False # for _idx, para in enumerate(paras): # if sent in para: # sent_align_para.append([_idx]) # last_idx = _idx # flag = True # break # if not flag: # sent_align_para.append([last_idx]) # assert len(sent_align_para) == len(art_sent) # sent_align_para.append([last_idx + 1]) # if docgraph: # inputs.append((art_sent, nodewords, nodefreq, triples, relations, sent_node_aligns)) # elif paragraph: # inputs.append((art_sent, nodewords, nodefreq, triples, # relations, sent_align_para, node_lists)) # else: # raise Exception('wrong graph type') if docgraph: inputs = list( zip(raw_art_sents, sources, nodewords, nodefreq, triples, relations, truncated_word_nums)) else: node_lists = nodelength inputs = list( zip(raw_art_sents, sources, nodewords, nodefreq, triples, relations, node_lists, truncated_word_nums)) assert len(inputs) == len(abs_sents) return inputs, abs_sents
def coll(key, adj_type, batch): split_token = '<split>' pad = 0 art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras = unzip( batch) def is_good_data(d): """ make sure data is not empty""" source_sents, extracts, nodes, edges, subgraphs, paras = d return source_sents and extracts and nodes and subgraphs and paras art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras = list( zip(*list( filter( is_good_data, zip(art_batch, abs_batch, all_nodes, all_edges, all_subgraphs, all_paras))))) art_sents = list(filter(bool, map(tokenize(None), art_batch))) abs_sents = list(filter(bool, map(tokenize(None), abs_batch))) inputs = [] # merge cluster for art_sent, nodes, edges, subgraphs, paras in zip( art_sents, all_nodes, all_edges, all_subgraphs, all_paras): max_len = len(list(concat(art_sent))) _, word_inpara_freq_feat, _, sent_inpara_freq_feat = create_word_freq_in_para_feat( paras, art_sent, list(concat(art_sent))) nodewords, nodelength, nodefreq, sum_worthy, triples, relations, sent_node_aligns = process_nodes( nodes, edges, max_len, max_sent_num=len(list(art_sent)), key=key, adj_type=adj_type) if paragraph: nodewords, node_lists, nodefreq, sum_worthy, triples, relations = process_subgraphs( nodes, edges, subgraphs, paras, max_len, max_sent=len(list(art_sent)), key=key, adj_type=adj_type) sent_align_para = [] last_idx = 0 for sent in range(len(art_sent)): flag = False for _idx, para in enumerate(paras): if sent in para: sent_align_para.append([_idx]) last_idx = _idx flag = True break if not flag: sent_align_para.append([last_idx]) assert len(sent_align_para) == len(art_sent) sent_align_para.append([last_idx + 1]) if docgraph: inputs.append((art_sent, nodewords, nodefreq, word_inpara_freq_feat, sent_inpara_freq_feat, triples, relations, sent_node_aligns)) elif paragraph: inputs.append( (art_sent, nodewords, nodefreq, word_inpara_freq_feat, sent_inpara_freq_feat, triples, relations, sent_align_para, node_lists)) else: raise Exception('wrong graph type') assert len(inputs) == len(abs_sents) return inputs, abs_sents
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda): start = time() # setup model if abs_dir is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization abstractor = identity else: abstractor = Abstractor(abs_dir, max_len, cuda) if ext_dir is None: # NOTE: if no abstractor is provided then # it would be the lead-N extractor extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM] else: extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda) # setup loader def coll(batch): articles = list(filter(bool, batch)) return articles dataset = DecodeDataset(split) n_data = len(dataset) loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll ) # prepare save paths and logs for i in range(MAX_ABS_NUM): os.makedirs(join(save_path, 'output_{}'.format(i))) dec_log = {} dec_log['abstractor'] = (None if abs_dir is None else json.load(open(join(abs_dir, 'meta.json')))) dec_log['extractor'] = (None if ext_dir is None else json.load(open(join(ext_dir, 'meta.json')))) dec_log['rl'] = False dec_log['split'] = split dec_log['beam'] = 1 # greedy decoding only with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) # Decoding i = 0 with torch.no_grad(): for i_debug, raw_article_batch in enumerate(loader): tokenized_article_batch = map(tokenize(None), raw_article_batch) ext_arts = [] ext_inds = [] for raw_art_sents in tokenized_article_batch: ext = extractor(raw_art_sents) ext_inds += [(len(ext_arts), len(ext))] ext_arts += list(map(lambda i: raw_art_sents[i], ext)) dec_outs = abstractor(ext_arts) assert i == batch_size*i_debug for j, n in ext_inds: decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]] for k, dec_str in enumerate(decoded_sents): with open(join(save_path, 'output_{}/{}.dec'.format(k, i)), 'w') as f: f.write(make_html_safe(dec_str)) i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i/n_data*100, timedelta(seconds=int(time()-start)) ), end='') print()
def coll(batch):
    art_batch, abs_batch, extracted = unzip(batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    extracted = list(filter(bool, extracted))
    return art_sents, abs_sents, extracted
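
# Usage note (sketch): like the other coll variants in this file, this function is meant
# to be passed to a torch DataLoader as collate_fn, mirroring the decode routines' local
# coll helpers, e.g.
#   loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
#                       num_workers=4, collate_fn=coll)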
def decode(save_path, model_dir, split, batch_size, beam_size, diverse, max_len, cuda): start = time() # setup model with open(join(model_dir, 'meta.json')) as f: meta = json.loads(f.read()) if meta['net_args']['abstractor'] is None: # NOTE: if no abstractor is provided then # the whole model would be extractive summarization assert beam_size == 1 abstractor = identity else: if beam_size == 1: abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda) else: abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda) extractor = RLExtractor(model_dir, cuda=cuda) # setup loader def coll(batch): articles, abstract, extracted = unzip(batch) articles = list(filter(bool, articles)) abstract = list(filter(bool, abstract)) extracted = list(filter(bool, extracted)) return articles, abstract, extracted dataset = DecodeDataset(split) n_data = len(dataset[0]) # article sentence loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=coll ) # prepare save paths and logs if os.path.exists(join(save_path, 'output')): pass else: os.makedirs(join(save_path, 'output')) dec_log = {} dec_log['abstractor'] = meta['net_args']['abstractor'] dec_log['extractor'] = meta['net_args']['extractor'] dec_log['rl'] = True dec_log['split'] = split dec_log['beam'] = beam_size dec_log['diverse'] = diverse with open(join(save_path, 'log.json'), 'w') as f: json.dump(dec_log, f, indent=4) file_path = os.path.join(save_path, 'Attention') act_path = os.path.join(save_path, 'Actions') header = "index,rouge_score1,rouge_score2,"+\ "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,doc_words_nums,"+\ "ext_words_nums, abs_words_nums, diff,"+\ "recall, precision, less_rewrite, preserve_action, rewrite_action, each_actions,"+\ "top3AsAns, top3AsGold, any_top2AsAns, any_top2AsGold,true_rewrite,true_preserve\n" if not os.path.exists(file_path): print('create dir:{}'.format(file_path)) os.makedirs(file_path) if not os.path.exists(act_path): print('create dir:{}'.format(act_path)) os.makedirs(act_path) with open(join(save_path,'_statisticsDecode.log.csv'),'w') as w: w.write(header) # Decoding i = 0 with torch.no_grad(): for i_debug, (raw_article_batch, raw_abstract_batch, extracted_batch) in enumerate(loader): tokenized_article_batch = map(tokenize(None), raw_article_batch) tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch) token_nums_batch = list(map(token_nums(None), raw_article_batch)) ext_nums = [] ext_arts = [] ext_inds = [] rewrite_less_rouge = [] dec_outs_act = [] ext_acts = [] abs_collections = [] ext_collections = [] # 抽句子 for ind, (raw_art_sents, abs_sents) in enumerate(zip(tokenized_article_batch ,tokenized_abstract_batch)): (ext, (state, act_dists)), act = extractor(raw_art_sents) # exclude EOE extracted_state = state[extracted_batch[ind]] attn = torch.softmax(state.mm(extracted_state.transpose(1,0)),dim=-1) # (_, abs_state), _ = extractor(abs_sents) # exclude EOE def plot_actDist(actons, nums): print('indiex: {} distribution ...'.format(nums)) # Write MDP State Attention weight matrix file_name = os.path.join(act_path, '{}.attention.pdf'.format(nums)) pdf_pages = PdfPages(file_name) plot_attention(actons.cpu().numpy(), name='{}-th artcle'.format(nums), X_label=list(range(len(raw_art_sents))), Y_label=list(range(len(ext))), dirpath=save_path, pdf_page=pdf_pages,action=True) pdf_pages.close() # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i) def plot_attn(): print('indiex: {} write_attention_pdf ...'.format(i + ind)) # Write MDP State 
Attention weight matrix file_name = os.path.join(file_path, '{}.attention.pdf'.format(i+ind)) pdf_pages = PdfPages(file_name) plot_attention(attn.cpu().numpy(), name='{}-th artcle'.format(i+ind), X_label=extracted_batch[ind],Y_label=list(range(len(raw_art_sents))), dirpath=save_path, pdf_page=pdf_pages) pdf_pages.close() # plot_attn() ext = ext[:-1] act = act[:-1] if not ext: # use top-5 if nothing is extracted # in some rare cases rnn-ext does not extract at all ext = list(range(5))[:len(raw_art_sents)] act = list([1]*5)[:len(raw_art_sents)] else: ext = [i.item() for i in ext] act = [i.item() for i in act] ext_nums.append(ext) ext_inds += [(len(ext_arts), len(ext))] # [(0,5),(5,7),(7,3),...] ext_arts += [raw_art_sents[k] for k in ext] ext_acts += [k for k in act] # 計算累計的句子 ext_collections += [sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0]+k+1],[]) for k in range(ext_inds[-1][1])] abs_collections += [sum(abs_sents[:k+1],[]) if k<len(abs_sents) else sum(abs_sents[0:len(abs_sents)],[]) for k in range(ext_inds[-1][1])] if beam_size > 1: # do n times abstract all_beams = abstractor(ext_arts, beam_size, diverse) dec_outs = rerank_mp(all_beams, ext_inds) dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds] dec_collections = [x for sublist in dec_collections for x in sublist] for index, chooser in enumerate(ext_acts): if chooser == 0: dec_outs_act += [dec_outs[index]] else: dec_outs_act += [ext_arts[index]] assert len(ext_collections)==len(dec_collections)==len(abs_collections) for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts): # for each sent in extracted digest # All abstract mapping rouge_before_rewriten = compute_rouge_n(ext, abss, n=1) rouge_after_rewriten = compute_rouge_n(dec, abss, n=1) diff_ins = rouge_before_rewriten - rouge_after_rewriten rewrite_less_rouge.append(diff_ins) else: # do 1st abstract dec_outs = abstractor(ext_arts) dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds] dec_collections = [x for sublist in dec_collections for x in sublist] for index, chooser in enumerate(ext_acts): if chooser == 0: dec_outs_act += [dec_outs[index]] else: dec_outs_act += [ext_arts[index]] # dec_outs_act = dec_outs # dec_outs_act = ext_arts assert len(ext_collections)==len(dec_collections)==len(abs_collections) for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts): # for each sent in extracted digest # All abstract mapping rouge_before_rewriten = compute_rouge_n(ext, abss, n=1) rouge_after_rewriten = compute_rouge_n(dec, abss, n=1) diff_ins = rouge_before_rewriten - rouge_after_rewriten rewrite_less_rouge.append(diff_ins) assert i == batch_size*i_debug for iters, (j, n) in enumerate(ext_inds): do_right_rewrite = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge<0 and action==0]) do_right_preserve = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge>=0 and action==1]) decoded_words_nums = [len(dec) for dec in dec_outs_act[j:j+n]] ext_words_nums = [token_nums_batch[iters][x] for x in range(len(token_nums_batch[iters])) if x in ext_nums[iters]] # 皆取extracted label # decoded_sents = [raw_article_batch[iters][x] for x in extracted_batch[iters]] # 統計數據 [START] decoded_sents = [' '.join(dec) for dec in dec_outs_act[j:j+n]] rouge_score1 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=1) rouge_score2 = compute_rouge_n(' 
'.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=2) rouge_scorel = compute_rouge_l(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters])) dec_sent_nums = len(decoded_sents) abs_sent_nums = len(raw_abstract_batch[iters]) doc_sent_nums = len(raw_article_batch[iters]) doc_words_nums = sum(token_nums_batch[iters]) ext_words_nums = sum(ext_words_nums) abs_words_nums = sum(decoded_words_nums) label_recall = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(extracted_batch[iters]) label_precision = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(ext_nums[iters]) less_rewrite = rewrite_less_rouge[j+n-1] dec_one_action_num = sum(ext_acts[j:j+n]) dec_zero_action_num = n - dec_one_action_num ext_indices = '_'.join([str(i) for i in ext_nums[iters]]) top3 = set([0,1,2]) <= set(ext_nums[iters]) top3_gold = set([0,1,2]) <= set(extracted_batch[iters]) # Any Top 2 top2 = set([0,1]) <= set(ext_nums[iters]) or set([1,2]) <= set(ext_nums[iters]) or set([0,2]) <= set(ext_nums[iters]) top2_gold = set([0,1]) <= set(extracted_batch[iters]) or set([1,2]) <= set(extracted_batch[iters]) or set([0,2]) <= set(extracted_batch[iters]) with open(join(save_path,'_statisticsDecode.log.csv'),'a') as w: w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(i,rouge_score1, rouge_score2, rouge_scorel, dec_sent_nums, abs_sent_nums, doc_sent_nums, doc_words_nums, ext_words_nums,abs_words_nums,(ext_words_nums - abs_words_nums), label_recall, label_precision, less_rewrite, dec_one_action_num, dec_zero_action_num, ext_indices, top3, top3_gold, top2, top2_gold,do_right_rewrite,do_right_preserve)) # 統計數據 END with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f: decoded_sents = [i for i in decoded_sents if i!=''] if len(decoded_sents) > 0: f.write(make_html_safe('\n'.join(decoded_sents))) else: f.write('') i += 1 print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( i, n_data, i/n_data*100, timedelta(seconds=int(time()-start)) ), end='') print()
def prepro_gat_nobert(batch, max_sent_len=100, max_sent=60, node_max_len=30,
                      key='summary_worthy', adj_type='concat_triple'):
    source_sents, nodes, edges, paras = batch
    tokenized_sents = tokenize(max_sent_len, source_sents)[:max_sent]
    tokenized_sents_2 = tokenize(None, source_sents)[:max_sent]
    tokenized_article = list(concat(tokenized_sents_2))
    max_len = len(tokenized_article)
    # tokenized_sents = [tokenized_sent + ['[SEP]'] for tokenized_sent in tokenized_sents]
    # tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
    word_num = [len(tokenized_sent) for tokenized_sent in tokenized_sents]
    truncated_word_num = word_num

    # find out-of-range and useless nodes
    other_nodes = set()
    oor_nodes = []  # out-of-range nodes will not be included in the graph

    word_freq_feat, word_inpara_feat, sent_freq_feat, sent_inpara_freq_feat = \
        create_word_freq_in_para_feat(paras, tokenized_sents, tokenized_article)
    assert len(word_freq_feat) == len(tokenized_article) and \
        len(word_inpara_feat) == len(tokenized_article)

    for _id, content in nodes.items():
        words = [pos for mention in content['content']
                 for pos in mention['word_pos'] if pos != -1]
        words = [word for word in words if word < max_len]
        if len(words) != 0:
            other_nodes.add(_id)
        else:
            oor_nodes.append(_id)

    activated_nodes = set()
    for _id, content in edges.items():
        if content['content']['arg1'] not in oor_nodes and \
                content['content']['arg2'] not in oor_nodes:
            words = content['content']['word_pos']
            new_words = [word for word in words if word > -1 and word < max_len]
            if len(new_words) > 0:
                activated_nodes.add(content['content']['arg1'])
                activated_nodes.add(content['content']['arg2'])
    oor_nodes.extend(list(other_nodes - activated_nodes))

    # process nodes
    sorted_nodes = sorted(nodes.items(), key=lambda x: int(x[0].split('_')[1]))
    nodewords = []
    nodefreq = []
    nodeinsent = []
    nodetype = []
    sum_worthy = []
    id2node = {}
    ii = 0
    for _id, content in sorted_nodes:
        if _id not in oor_nodes:
            words = [pos for mention in content['content']
                     for pos in mention['word_pos'] if pos != -1]
            words = [word for word in words if word < max_len]
            words = words[:node_max_len]
            sum_worthy.append(content[key])
            if len(words) != 0:
                nodewords.append(words)
                nodefreq.append(len(content['content']))
                nodetype.append(1)
                nodeinsent.append([mention['sent_pos'] for mention in content['content']
                                   if mention['sent_pos'] < len(tokenized_sents)])
                id2node[_id] = ii
                ii += 1
            else:
                oor_nodes.append(_id)
    if len(nodewords) == 0:
        # print('warning! no nodes in this sample')
        nodewords = [[0], [2]]
        nodefreq.extend([1, 1])
        nodeinsent.extend([[0], [0]])
        nodetype.extend([1, 1])
        sum_worthy.extend([0, 0])
    nodelength = [len(words) for words in nodewords]

    # process edges
    acticated_nodes = set()
    triples = []
    edge_freq = []
    edgeinsent = []
    edgetype = []
    relations = []
    sum_worthy_edges = []
    sorted_edges = sorted(edges.items(), key=lambda x: int(x[0].split('_')[1]))
    ii = 0
    for _id, content in sorted_edges:
        if content['content']['arg1'] not in oor_nodes and \
                content['content']['arg2'] not in oor_nodes:
            words = content['content']['word_pos']
            new_words = [word for word in words if word > -1 and word < max_len]
            new_words = new_words[:node_max_len]
            if len(new_words) > 0:
                node1 = id2node[content['content']['arg1']]
                node2 = id2node[content['content']['arg2']]
                sum_worthy_edges.append(content[key])
                try:
                    sent_pos = [content['content']['sent_pos']]
                except KeyError:
                    sent_pos = [content['content']['arg1_original'][0]['sent_pos']]
                if adj_type == 'edge_up':
                    nodewords[node1].extend(new_words)
                elif adj_type == 'edge_down':
                    nodewords[node2].extend(new_words)
                edge = int(_id.split('_')[1])
                edge_freq.append(1)
                edgeinsent.append(sent_pos)
                edgetype.append(2)
                triples.append([node1, ii, node2])
                acticated_nodes.add(content['content']['arg1'])
                acticated_nodes.add(content['content']['arg2'])
                ii += 1
                relations.append(new_words)
    if len(relations) == 0:
        # print('warning! no edges in this sample')
        relations = [[1]]
        triples = [[0, 0, 1]]
        edgeinsent.append([0])
        edge_freq = [1]
        edgetype.append(2)
        sum_worthy_edges.extend([0])
    nodefreq = [freq if freq < MAX_FREQ - 1 else MAX_FREQ - 1 for freq in nodefreq]
    rlength = [len(words) for words in relations]
    if adj_type == 'edge_as_node':
        nodewords = nodewords + relations
        nodelength = nodelength + rlength
        sum_worthy = sum_worthy + sum_worthy_edges
        nodefreq = nodefreq + edge_freq
        nodetype = nodetype + edgetype
        nodeinsent = nodeinsent + edgeinsent
    sent_node_aligns = create_sent_node_align(nodeinsent, len(tokenized_sents))

    return tokenized_article, truncated_word_num, \
        (nodewords, nodelength, sum_worthy, nodefreq, word_freq_feat,
         word_inpara_feat, sent_freq_feat, sent_inpara_freq_feat,
         sent_node_aligns), \
        (relations, rlength, triples)
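# Hedged sketch (not part of the original code): a toy view of the node/edge
# dictionaries prepro_gat_nobert reads. The field names ('content', 'word_pos',
# 'sent_pos', 'arg1', 'arg2', 'summary_worthy') are the ones accessed above;
# the ids, positions, and labels here are invented. The loop reproduces only the
# out-of-range filtering step, so it is self-contained and runnable.
toy_nodes = {
    'node_0': {'content': [{'word_pos': [0, 1], 'sent_pos': 0}], 'summary_worthy': 1},
    'node_1': {'content': [{'word_pos': [-1], 'sent_pos': 1}], 'summary_worthy': 0},
}
toy_edges = {
    'edge_0': {'content': {'arg1': 'node_0', 'arg2': 'node_1', 'word_pos': [2]},
               'summary_worthy': 0},
}

max_len = 10   # length of the flattened (truncated) article, as in the function above
oor_nodes = []  # nodes whose mentions all fall outside the truncated article
for _id, content in toy_nodes.items():
    words = [pos for mention in content['content']
             for pos in mention['word_pos'] if pos != -1]
    words = [w for w in words if w < max_len]
    if not words:
        oor_nodes.append(_id)
print(oor_nodes)  # ['node_1'] -> its only mention position is -1 (padding)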
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = lambda x, y: x
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            print('BEAM')
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    try:
        os.makedirs(join(save_path, 'output'))
    except:
        pass
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    total_leng = 0
    total_num = 0
    with torch.no_grad():
        for i_debug, data_batch in enumerate(loader):
            raw_article_batch, sent_label_batch = tuple(map(list, unzip(data_batch)))
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            #ext_arts = []
            ext_inds = []
            dirty = []
            ext_sents = []
            masks = []
            for raw_art_sents, sent_labels in zip(tokenized_article_batch,
                                                  sent_label_batch):
                ext = extractor(raw_art_sents, sent_labels)  # exclude EOE
                tmp_size = min(max_dec_edu, len(ext) - 1)
                #total_leng += sum([len(e) - 1 for e in ext[:-1]])
                #total_num += len(ext) - 1
                #print(tmp_size, len(ext) - 1)
                ext_inds += [(len(ext_sents), tmp_size)]
                tmp_stop = ext[-1][-1].item()
                tmp_truncate = tmp_stop - 1
                str_arts = list(map(lambda x: ' '.join(x), raw_art_sents))
                for idx in ext[:tmp_size]:
                    t, m = rl_edu_to_sentence(str_arts, idx)
                    total_leng += len(t)
                    total_num += 1
                    assert len(t) == len(m)
                    if t == []:
                        assert len(idx) == 1
                        id = idx[0].item()
                        if id == tmp_truncate:
                            dirty.append(len(ext_sents))
                            ext_sents.append(label)
                            masks.append(label_mask)
                    else:
                        if idx[-1].item() != tmp_stop:
                            ext_sents.append(t)
                            masks.append(m)
                #ext_arts += [raw_art_sents[i] for i in ext]
            #print(ext_sents)
            #print(masks)
            #print(dirty)
            #exit(0)
            if beam_size > 1:
                #print(ext_sents)
                #print(masks)
                all_beams = abstractor(ext_sents, masks, beam_size, diverse)
                print('rerank')
                dec_outs = rerank_mp(all_beams, ext_inds)
                for d in dirty:
                    dec_outs[d] = []
                # TODO:!!!!!!!!!!!
            else:
                dec_outs = abstractor(ext_sents, masks)
                for d in dirty:
                    dec_outs[d] = []
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                if i % 100 == 0:
                    print(total_leng / total_num)
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))
                ), end='')
    print()
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    if beam_size == 1:
        abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
    else:
        abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        articles = [" ".join(article) for article in articles]
        return articles
    dataset = DecodeDataset(args.data_path, split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = tokenize(1000, raw_article_batch)
            batch_size = len(tokenized_article_batch)
            ext_inds = []
            for num in range(batch_size):
                ext_inds += [(num, 1)]
            if beam_size > 1:
                all_beams = abstractor(tokenized_article_batch, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(tokenized_article_batch)
            assert i == batch_size * i_debug
            for index in range(batch_size):
                decoded_sents = [' '.join(dec.split(",")) for dec in dec_outs[index]]
                with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                    f.write(make_html_safe(' '.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                    end='')
    print()
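# Hedged sketch (not part of the original code): across these decode variants,
# ext_inds stores (offset, count) pairs so that a flat list of decoded sentences
# can be sliced back into per-article summaries. The data below is invented.
dec_outs = [['sent', 'a'], ['sent', 'b'], ['sent', 'c']]
ext_inds = [(0, 2), (2, 1)]  # article 0 -> two sentences, article 1 -> one
for j, n in ext_inds:
    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
    print(decoded_sents)
# ['sent a', 'sent b']
# ['sent c']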
def decode(args):
    save_path = args.path
    model_dir = args.model_dir
    batch_size = args.batch
    beam_size = args.beam
    diverse = args.div
    max_len = args.max_dec_word
    cuda = args.cuda

    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(args)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    if not os.path.exists(join(save_path, 'output')):
        os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = args.mode
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                decoded_sents = decoded_sents[:20]
                with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                    end='')
    print()
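# Hedged sketch (not part of the original code): the command-line wiring this
# decode(args) variant appears to assume. Flag names mirror the attributes read
# above (args.path, args.model_dir, args.batch, args.beam, args.div,
# args.max_dec_word, args.cuda, args.mode); the defaults and help strings are
# assumptions, not the project's actual CLI.
import argparse

parser = argparse.ArgumentParser(description='run decoding of the full model (sketch)')
parser.add_argument('--path', required=True, help='directory for decoded outputs')
parser.add_argument('--model_dir', required=True, help='root of the trained model')
parser.add_argument('--batch', type=int, default=32, help='decoding batch size')
parser.add_argument('--beam', type=int, default=1, help='beam size for beam-search decoding')
parser.add_argument('--div', type=float, default=1.0, help='diversity ratio for diverse beam search')
parser.add_argument('--max_dec_word', type=int, default=30, help='maximum words decoded per sentence')
parser.add_argument('--cuda', action='store_true', help='use GPU')
parser.add_argument('--mode', default='test', help='data split, logged as dec_log["split"]')
args = parser.parse_args()
# decode(args)  # wiring shown for illustration only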
def decode(save_path, save_file, model_dir, split, batch_size, beam_size,
           diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                    end='')
    print()

    # probably not needed for the cnn/dailymail dataset:
    # merge the per-article output/*.dec files into a single file, in index order
    f = open(save_file, "w")
    summaries_files = os.listdir(join(save_path, 'output'))
    n = len(summaries_files)
    summaries_list = [""] * n
    for fname in summaries_files:
        num = int(fname.replace(".dec", ""))
        f_local = open(join(save_path, "output", fname))
        summaries_list[num] = f_local.read().replace("\n", " ")
        f_local.close()
    assert (len(summaries_list) == n)
    f.write("\n".join(summaries_list))
    f.close()
def test(args, split):
    ext_dir = args.path
    ckpts = sort_ckpt(ext_dir)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=args.batch, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # decode and evaluate top 5 models
    os.mkdir(join(args.path, 'decode'))
    os.mkdir(join(args.path, 'ROUGE'))
    for i in range(min(5, len(ckpts))):
        print('Start loading checkpoint {} !'.format(ckpts[i]))
        cur_ckpt = torch.load(
            join(ext_dir, 'ckpt/{}'.format(ckpts[i]))
        )['state_dict']
        extractor = Extractor(ext_dir, cur_ckpt, args.emb_type, cuda=args.cuda)
        save_path = join(args.path, 'decode/{}'.format(ckpts[i]))
        os.mkdir(save_path)

        # decoding
        ext_list = []
        cur_idx = 0
        start = time()
        with torch.no_grad():
            for raw_article_batch in loader:
                tokenized_article_batch = map(tokenize(None, args.emb_type),
                                              raw_article_batch)
                for raw_art_sents in tokenized_article_batch:
                    ext_idx = extractor(raw_art_sents)
                    ext_list.append(ext_idx)
                    cur_idx += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        cur_idx, n_data, cur_idx / n_data * 100,
                        timedelta(seconds=int(time() - start))
                    ), end='')
        print()

        # write files
        for file_idx, ext_ids in enumerate(ext_list):
            dec = []
            data_path = join(DATA_DIR, '{}/{}.json'.format(split, file_idx))
            with open(data_path) as f:
                data = json.loads(f.read())
            n_ext = 2 if data['source'] == 'CNN' else 3
            n_ext = min(n_ext, len(data['article']))
            for j in range(n_ext):
                sent_idx = ext_ids[j]
                dec.append(data['article'][sent_idx])
            with open(join(save_path, '{}.dec'.format(file_idx)), 'w') as f:
                for sent in dec:
                    print(sent, file=f)

        # evaluate current model
        print('Starting evaluating ROUGE !')
        dec_path = save_path
        ref_path = join(DATA_DIR, 'refs/{}'.format(split))
        ROUGE = eval_rouge(dec_path, ref_path)
        print(ROUGE)
        with open(join(args.path, 'ROUGE/{}.txt'.format(ckpts[i])), 'w') as f:
            print(ROUGE, file=f)
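# Hedged sketch (not part of the original code): mapping extractor-chosen sentence
# indices back to article text for one example, mirroring the "write files" loop in
# test() above. The JSON fields ('source', 'article') are the ones test() reads;
# the path and index values in the usage line are invented.
import json
from os.path import join

def indices_to_summary(data_path, ext_ids):
    with open(data_path) as f:
        data = json.loads(f.read())
    n_ext = 2 if data['source'] == 'CNN' else 3  # same per-source budget as test()
    n_ext = min(n_ext, len(data['article']))
    return [data['article'][ext_ids[j]] for j in range(n_ext)]

# usage (hypothetical path and indices):
# print(indices_to_summary(join('finished_files', 'test', '0.json'), [4, 0, 7]))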
def decode_entity(save_path, model_dir, split, batch_size, beam_size, diverse,
                  max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda, min_len=min_len)
    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda, entity=True)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        batch = list(filter(bool, batch))
        return batch
    if args.key == 1:
        key = 'filtered_rule1_input_mention_cluster'
    elif args.key == 2:
        key = 'filtered_rule23_6_input_mention_cluster'
    else:
        raise Exception
    dataset = DecodeDatasetEntity(split, key)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        sent_selected = 0
        with torch.no_grad():
            for i_debug, raw_input_batch in enumerate(loader):
                raw_article_batch, clusters = zip(*raw_input_batch)
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                #processed_clusters = map(preproc(list(tokenized_article_batch), clusters))
                #processed_clusters = list(zip(*processed_clusters))
                ext_arts = []
                ext_inds = []
                pre_abs = []
                beam_inds = []
                for raw_art_sents, raw_cls in zip(tokenized_article_batch, clusters):
                    processed_clusters = preproc(raw_art_sents, raw_cls)
                    ext = extractor((raw_art_sents, processed_clusters))[:]  # exclude EOE
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]
                if beam_size > 1:
                    # all_beams = abstractor(ext_arts, beam_size, diverse)
                    # dec_outs = rerank_mp(all_beams, ext_inds)
                    all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                    dec_outs = rerank_mp(all_beams, beam_inds)
                else:
                    dec_outs = abstractor(pre_abs)
                for dec_out in dec_outs:
                    dec_out = sent_tokenize(' '.join(dec_out))
                    ext = [sent.split(' ') for sent in dec_out]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += ext
                dec_outs = ext_arts
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))
                    ), end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))
                    ), end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
    # note: sent_selected is only accumulated in the sc branch above
    print('average sentence selected:', sent_selected)
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda,
           min_len):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        #abstractor = Abstractor(abs_dir, max_len, cuda)
        abstractor = BeamAbstractor(abs_dir, max_len, cuda, min_len,
                                    reverse=args.reverse)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        if args.no_force_ext:
            extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda,
                                  force_ext=not args.no_force_ext)
        else:
            extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)
    os.makedirs(save_path)

    # prepare save paths and logs
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None
                             else json.load(open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None
                            else json.load(open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    if abs_dir is not None:
        dec_log['beam'] = 5  # beam search decoding when an abstractor is provided
        beam_size = 5
    else:
        dec_log['beam'] = 1
        beam_size = 1
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    print(dec_log['extractor'])
    if dec_log['extractor']['net_args']['stop'] == False and not args.no_force_ext:
        for i in range(MAX_ABS_NUM + 1):
            os.makedirs(join(save_path, 'output_{}'.format(i)))
    else:
        os.makedirs(join(save_path, 'output'))

    # Decoding
    i = 0
    length = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            pre_abs = []
            beam_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_art = list(map(lambda i: raw_art_sents[i], ext))
                pre_abs.append([word for sent in ext_art for word in sent])
                beam_inds += [(len(beam_inds), 1)]
            if beam_size > 1:
                all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                dec_outs = rerank_mp(all_beams, beam_inds)
            else:
                dec_outs = abstractor(pre_abs)
            for dec_out in dec_outs:
                dec_out = sent_tokenize(' '.join(dec_out))
                ext = [sent.split(' ') for sent in dec_out]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += ext
            if dec_log['extractor']['net_args']['stop'] == False and not args.no_force_ext:
                dec_outs = ext_arts
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                    for k, dec_str in enumerate(decoded_sents):
                        if k > MAX_ABS_NUM - 2:
                            break
                        with open(join(save_path,
                                       'output_{}/{}.dec'.format(k, i)), 'w') as f:
                            f.write(make_html_safe(dec_str))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                        end='')
            else:
                dec_outs = ext_arts
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)), 'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                        end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    # os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    count = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                # (start, finish) are the (offset, count) pairs stored in ext_inds;
                # note that this `start` shadows the timer set at the top of decode
                for ind_file, (start, finish) in enumerate(ext_inds):
                    article_beams = all_beams[start:start + finish]
                    file = {}
                    for ind_sent, sent in enumerate(article_beams):
                        file[ind_sent] = defaultdict(list)
                        sentence = " ".join(ext_arts[start + ind_sent])
                        file[ind_sent]['sentence'].append(sentence)
                        for hypothesis in sent:
                            file[ind_sent]['summarizer_logprob'].append(
                                hypothesis.logprob)
                            file[ind_sent]['hypotheses'].append(
                                " ".join(hypothesis.sequence))
                    with open(os.path.join('exported_beams',
                                           '{}.json'.format(count + ind_file)),
                              'w') as f:
                        json.dump(file, f, ensure_ascii=False)
            count += batch_size
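# Hedged sketch (not part of the original code): reading one of the JSON files
# written to exported_beams/ above and keeping the highest-scoring hypothesis per
# extracted sentence. The keys ('sentence', 'summarizer_logprob', 'hypotheses')
# mirror the writer; the file path in the usage line is invented.
import json

def best_hypotheses(beam_file):
    with open(beam_file) as f:
        beams = json.load(f)
    best = {}
    for ind_sent, entry in beams.items():  # JSON keys come back as strings
        scored = sorted(zip(entry['summarizer_logprob'], entry['hypotheses']),
                        key=lambda pair: pair[0], reverse=True)
        # fall back to the original extracted sentence if no hypotheses were stored
        best[int(ind_sent)] = scored[0][1] if scored else entry['sentence'][0]
    return best

# usage (hypothetical file):
# print(best_hypotheses('exported_beams/0.json'))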