# NOTE: these functions assume the surrounding summarization repo provides
# tokenize, postprocess, rerank_mp, identity, make_html_safe, abstract_callback,
# unzip, CnnDmDataset, DecodeDataset, Abstractor, BeamAbstractor, Extractor,
# RLExtractor, MAX_ABS_NUM, DATA_DIR and DATASET_DIR.
import json
import logging
import os
from datetime import timedelta
from os.path import join
from time import time

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def decode(raw_sentences, extractor, abstractor, beam_size,
           diverse=1, token_threshold=5, postpro=False, debug=False):
    with torch.no_grad():
        start = time()
        tokenized_sentences = tokenize(None, raw_sentences)
        ext = extractor(tokenized_sentences)[:-1]  # exclude EOE
        if not ext:
            # use top-5 if nothing is extracted
            # in some rare cases rnn-ext does not extract at all
            ext = list(range(5))[:len(tokenized_sentences)]
        else:
            ext = [i.item() for i in ext]
        ext_sentences = [tokenized_sentences[i] for i in ext]
        if beam_size > 1:
            all_beams = abstractor(ext_sentences, beam_size, diverse)
            dec_outs = rerank_mp(all_beams, [(0, len(ext_sentences))],
                                 debug=debug)
        else:
            dec_outs = abstractor(ext_sentences, debug=debug)
        attns = None
        if debug:
            dec_outs, attns = dec_outs
            # trim each attention matrix to the length of its source sentence
            attns = [[t[:len(ext_sentences[i])] for t in attn]
                     for i, attn in enumerate(attns)]
            source_col_name = 'source'
            # one DataFrame per sentence: rows are source tokens,
            # columns are decoded tokens
            attns = [
                pd.DataFrame({
                    source_col_name: ext_sentences[i],
                    **{dec_outs[i][j]: t for j, t in enumerate(attn)},
                }).set_index(source_col_name)
                for i, attn in enumerate(attns)
            ]
        if postpro:
            decoded_sentences = postprocess(dec_outs, token_threshold)
        else:
            decoded_sentences = [' '.join(dec) for dec in dec_outs]
        logging.info('decoded {} sentences in {:.3f}s'.format(
            len(raw_sentences), time() - start))
        if debug:
            return (ext, decoded_sentences), attns
        return ext, decoded_sentences
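# Hypothetical usage sketch for the interactive decoder above. The model
# directories, max_len and beam settings are illustrative assumptions, not
# values fixed by this repo; the constructor signatures follow the batch
# decoders further below:
#
#     extractor = Extractor('pretrained/extractor', max_ext=MAX_ABS_NUM,
#                           cuda=torch.cuda.is_available())
#     abstractor = BeamAbstractor('pretrained/abstractor', 30,
#                                 torch.cuda.is_available())
#     ext_idxs, summary_sents = decode(raw_sentences, extractor, abstractor,
#                                      beam_size=5, diverse=1.0)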
def decode(args, split):
    decode_path = join(args.path, 'abs_decode')
    os.makedirs(decode_path, exist_ok=True)
    dataset = CnnDmDataset(split, DATA_DIR)
    print('Generating abstracts for {} dataset'.format(split))
    for i in tqdm(range(len(dataset))):
        file_path = join(decode_path, '{}.dec'.format(i))
        if not args.force and os.path.exists(file_path):
            continue
        js_data = dataset[i]
        art_sents, extracts = js_data['article'], js_data['extracted']
        if not extracts:
            abs_results = []
        else:
            abs_results = abstract_callback(
                args, tokenize(None, (art_sents[j] for j in extracts)))
        with open(file_path, 'w', encoding='utf8') as f:
            f.write('\n'.join(abs_results))
def read_file(file_dir, file_name):
    with open(join(file_dir, file_name)) as f:
        return tokenize(None, f.read().splitlines())
def read_data_file(file_dir, file_name):
    with open(join(file_dir, file_name)) as f:
        return tokenize(None, json.load(f)['article'])
def coll(batch):
    art_batch, abs_batch = unzip(batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    return art_sents, abs_sents
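# Hypothetical usage sketch: `coll` above is meant as a DataLoader collate_fn
# for a dataset yielding (article, abstract) pairs; empty articles and
# abstracts are dropped by the filter(bool, ...) calls. `pair_dataset` is an
# assumed name, not part of the repo:
#
#     loader = DataLoader(pair_dataset, batch_size=32, collate_fn=coll)
#     for art_sents, abs_sents in loader:
#         ...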
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model is an extractive summarizer
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       the lead-N extractor is used
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split, DATASET_DIR)
    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    dec_log = {
        'abstractor': (None if abs_dir is None
                       else json.load(open(join(abs_dir, 'meta.json')))),
        'extractor': (None if ext_dir is None
                      else json.load(open(join(ext_dir, 'meta.json')))),
        'rl': False,
        'split': split,
        'beam': 1
    }
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[idx] for idx in ext]
            dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                # write the k-th rewritten sentence of article i to output_k/
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path,
                                   'output_{}/{}.dec'.format(k, i)),
                              'w') as f:
                        f.write(make_html_safe(dec_str))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))
                ), end='')
    print()
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.load(f)
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model is an extractive summarizer
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split, DATASET_DIR)
    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {
        'abstractor': meta['net_args']['abstractor'],
        'extractor': meta['net_args']['extractor'],
        'rl': True,
        'split': split,
        'beam': beam_size,
        'diverse': diverse
    }
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [idx.item() for idx in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[idx] for idx in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))
                ), end='')
    print()
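# Hypothetical invocation of the full RL decoder above, mirroring the kind of
# CLI driver that typically wraps such a function; all argument values here
# are illustrative assumptions, not defaults defined by this repo:
#
#     decode(save_path='output/test', model_dir='saved_model', split='test',
#            batch_size=32, beam_size=5, diverse=1.0, max_len=30,
#            cuda=torch.cuda.is_available())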