def configure_net_entity(abs_dir, ext_dir, cuda, sc, tv):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCriticEntity(extractor,
                                 ArticleBatcher(agent_vocab, cuda),
                                 time_variant=tv)
    else:
        raise Exception('actor critic entity model not implemented')
    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, agent_vocab, abstractor, net_args
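Most of these examples fall back to an `identity` abstractor when no abstractor checkpoint is given. A minimal sketch of that helper, assuming it simply passes the extracted sentences through unchanged:

def identity(x):
    """Pass-through abstractor used when abs_dir is None (assumed behavior)."""
    return x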
Example 2
def configure_net(abs_dir, ext_dir, cuda, sc, tv, rl_dir=''):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    if sc:
        agent = SelfCritic(extractor,
                           ArticleBatcher(agent_vocab, cuda),
                           time_variant=tv)
    else:
        agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))

    if rl_dir != '':
        ckpt = load_best_ckpt(rl_dir, reverse=True)
        agent.load_state_dict(ckpt)

    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))
    print('agent:', agent)

    return agent, agent_vocab, abstractor, net_args
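A hedged usage sketch of this variant; the checkpoint directories below are hypothetical placeholders:

import torch

agent, agent_vocab, abstractor, net_args = configure_net(
    abs_dir='pretrained/abstractor',   # hypothetical abstractor checkpoint dir
    ext_dir='pretrained/extractor',    # hypothetical extractor checkpoint dir
    cuda=torch.cuda.is_available(),
    sc=True,    # build the SelfCritic agent
    tv=False,   # disable the time-variant option
    rl_dir='')  # leave empty to skip warm-starting from an RL checkpoint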
Example 3
def configure_net(abs_dir, ext_dir, cuda, net_type='ml_rnn_extractor'):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc,
                        extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda, net_type=net_type))
    # agent = ActorCriticPreSumm(extractor._sent_enc,
    #                     extractor._art_enc,
    #                     extractor._extractor,
    #                     ArticleBatcher(agent_vocab, cuda))
    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None
                              else json.load(open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, agent_vocab, abstractor, net_args
def configure_net(abs_dir, ext_dir, cuda):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
        # abs_dir = abs_dir.split(';')
        # abstractor_sent = Abstractor(abs_dir[0], MAX_ABS_LEN, cuda)
        # abstractor_doc = Abstractor(abs_dir[1], MAX_ABS_LEN, cuda)
    else:
        abstractor_sent = identity  # keep the name used in the return below
    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor._sent_enc,
                        extractor._art_enc, extractor._extractor,
                        ArticleBatcher(agent_vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                               extractor._extractor,
                               ArticleBatcher(agent_vocab, cuda))

    target_agent.load_state_dict(agent.state_dict())

    if cuda:
        agent = agent.cuda()
        target_agent = target_agent.cuda()
    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, target_agent, agent_vocab, abstractor_sent, net_args
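This variant additionally builds a target network initialized from the agent. A minimal sketch of how such a target copy is commonly refreshed during training (the update interval is an assumption, not taken from this snippet):

def sync_target(agent, target_agent, step, interval=100):
    # hard update: copy the online agent's weights into the target every `interval` steps
    if step % interval == 0:
        target_agent.load_state_dict(agent.state_dict())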
Example 5
def configure_net_graph(abs_dir,
                        ext_dir,
                        cuda,
                        docgraph=True,
                        paragraph=False):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    assert not all([docgraph, paragraph])
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)

    agent = SelfCriticGraph(extractor, ArticleBatcherGraph(agent_vocab, cuda),
                            cuda, docgraph, paragraph)

    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, agent_vocab, abstractor, net_args
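A hedged call sketch; the paths are hypothetical, and the assert above means docgraph and paragraph must not both be True:

agent, agent_vocab, abstractor, net_args = configure_net_graph(
    abs_dir='pretrained/abstractor',  # hypothetical path
    ext_dir='pretrained/extractor',   # hypothetical path
    cuda=True,
    docgraph=True,    # document-level graph agent
    paragraph=False)  # mutually exclusive with docgraph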
Example 6
def main(article_path, model_dir, batch_size, beam_size, diverse, max_len,
         cuda):
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)
    with open(article_path) as f:
        raw_article_batch = f.readlines()
    tokenized_article_batch = map(tokenize(None), raw_article_batch)
    ext_arts = []
    ext_inds = []
    for raw_art_sents in tokenized_article_batch:
        print(raw_art_sents)
        ext = extractor(raw_art_sents)[:-1]  # exclude EOE
        if not ext:
            # use top-5 if nothing is extracted
            # in some rare cases rnn-ext does not extract at all
            ext = list(range(5))[:len(raw_art_sents)]
        else:
            ext = [i.item() for i in ext]
        ext_inds += [(len(ext_arts), len(ext))]
        ext_arts += [raw_art_sents[i] for i in ext]
    if beam_size > 1:
        all_beams = abstractor(ext_arts, beam_size, diverse)
        dec_outs = rerank_mp(all_beams, ext_inds)
    else:
        dec_outs = abstractor(ext_arts)
    # assert i == batch_size*i_debug
    for j, n in ext_inds:
        decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
        print(decoded_sents)
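Throughout these examples, `ext_inds` stores (offset, count) pairs into the flat `ext_arts` list so each article's extracted sentences can be recovered after batched abstraction. A small illustration with made-up data:

ext_arts = ['s0', 's1', 's2', 's3', 's4']  # extracted sentences, flattened over two articles
ext_inds = [(0, 3), (3, 2)]                # (start offset, number of sentences) per article
for j, n in ext_inds:
    print(ext_arts[j:j + n])               # ['s0', 's1', 's2'] then ['s3', 's4']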
Example 7
def configure_net(abs_dir, ext_dir, cuda):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = lambda x, y: x
    #print(abstractor([[], [0,0,0,0,0,0]], [[],[1,1,1,1,1,1]]))
    #exit(0)
    # load ML trained extractor net and build RL agent
    extractor, agent_vocab = load_ext_net(ext_dir)
    agent = ActorCritic(extractor, ArticleBatcher(agent_vocab, cuda))
    if cuda:
        agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, agent_vocab, abstractor, net_args
Example 8
def configure_net(abs_dir, ext_dir, cuda, aux_device, bert_max_len, args):
    """ load pretrained sub-modules and build the actor-critic network"""
    # load pretrained abstractor model
    if abs_dir is not None:
        abstractor = Abstractor(abs_dir, MAX_ABS_LEN, cuda)
    else:
        abstractor = identity

    # load ML trained extractor net and build RL agent
    extractor, agent_vocab, use_bert, bert_config = load_ext_net(
        ext_dir, aux_device, args)

    if use_bert:
        bert_type, tokenizer_cache = bert_config
        bert_tokenizer = BertTokenizer.from_pretrained(
            bert_type, cache_dir=tokenizer_cache)
        agent = BertActorCritic(extractor._sent_enc,
                                extractor._art_enc,
                                extractor._extractor,
                                BertArticleBatcher(bert_tokenizer,
                                                   bert_max_len,
                                                   args.bert_max_sent, cuda),
                                cuda=cuda,
                                aux_device=aux_device)

    else:
        agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                            extractor._extractor,
                            ArticleBatcher(agent_vocab, cuda))
        if cuda:
            agent = agent.cuda()

    net_args = {}
    net_args['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    net_args['extractor'] = json.load(open(join(ext_dir, 'meta.json')))

    return agent, agent_vocab, abstractor, net_args
def load_rl_ckpt(abs_dir, ext_dir, cuda):
    ext_meta = json.load(open(join(ext_dir, 'meta.json')))
    assert ext_meta['net'] == 'rnn-ext_abs_rl'
    ext_args = ext_meta['net_args']['extractor']['net_args']
    vocab = pkl.load(open(join(ext_dir, 'agent_vocab.pkl'), 'rb'))
    extractor = PtrExtractSumm(**ext_args)
    abstractor_sent = Abstractor(abs_dir, MAX_ABS_LEN, cuda)

    agent = ActorCritic(extractor._sent_enc, extractor._art_enc,
                        extractor._extractor, ArticleBatcher(vocab, cuda))

    target_agent = ActorCritic(extractor._sent_enc,
                               extractor._art_enc, extractor._extractor,
                               ArticleBatcher(vocab, cuda))

    ext_ckpt = load_best_ckpt(ext_dir, reverse=True)
    agent.load_state_dict(ext_ckpt)

    device = torch.device('cuda' if cuda else 'cpu')
    agent = agent.to(device)
    target_agent = target_agent.to(device)

    return agent, target_agent, vocab, abstractor_sent, ext_meta
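A hedged call sketch for this loader; the directories are hypothetical and must contain the meta.json, agent_vocab.pkl, and checkpoint files the function expects:

agent, target_agent, vocab, abstractor_sent, ext_meta = load_rl_ckpt(
    abs_dir='pretrained/abstractor',   # hypothetical abstractor checkpoint dir
    ext_dir='saved/rnn-ext_abs_rl',    # hypothetical RL extractor dir
    cuda=torch.cuda.is_available())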
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles, abstract, extracted = unzip(batch)
        articles = list(filter(bool, articles))
        abstract = list(filter(bool, abstract))
        extracted =  list(filter(bool, extracted))
        return articles, abstract, extracted

    dataset = DecodeDataset(split)
    n_data = len(dataset[0]) # article sentence
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )
    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse

    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    
    file_path = os.path.join(save_path, 'Attention')
    act_path = os.path.join(save_path, 'Actions')

    header = "index,rouge_score1,rouge_score2,"+\
    "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,doc_words_nums,"+\
    "ext_words_nums, abs_words_nums, diff,"+\
    "recall, precision, less_rewrite, preserve_action, rewrite_action, each_actions,"+\
    "top3AsAns, top3AsGold, any_top2AsAns, any_top2AsGold,true_rewrite,true_preserve\n"


    if not os.path.exists(file_path):
        print('create dir:{}'.format(file_path))
        os.makedirs(file_path)

    if not os.path.exists(act_path):
        print('create dir:{}'.format(act_path))
        os.makedirs(act_path)

    with open(join(save_path,'_statisticsDecode.log.csv'),'w') as w:
        w.write(header)  
        
    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, (raw_article_batch, raw_abstract_batch, extracted_batch) in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch)
            token_nums_batch = list(map(token_nums(None), raw_article_batch))

            ext_nums = []
            ext_arts = []
            ext_inds = []
            rewrite_less_rouge = []
            dec_outs_act = []
            ext_acts = []
            abs_collections = []
            ext_collections = []

            # extract sentences
            for ind, (raw_art_sents, abs_sents) in enumerate(zip(tokenized_article_batch ,tokenized_abstract_batch)):

                (ext, (state, act_dists)), act = extractor(raw_art_sents)  # exclude EOE
                extracted_state = state[extracted_batch[ind]]
                attn = torch.softmax(state.mm(extracted_state.transpose(1,0)),dim=-1)
                # (_, abs_state), _ = extractor(abs_sents)  # exclude EOE
                
                def plot_actDist(actions, nums):
                    print('index: {} distribution ...'.format(nums))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(act_path, '{}.attention.pdf'.format(nums))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(actions.cpu().numpy(), name='{}-th article'.format(nums),
                        X_label=list(range(len(raw_art_sents))), Y_label=list(range(len(ext))),
                        dirpath=save_path, pdf_page=pdf_pages, action=True)
                    pdf_pages.close()
                # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i)

                def plot_attn():
                    print('index: {} write_attention_pdf ...'.format(i + ind))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(file_path, '{}.attention.pdf'.format(i+ind))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(attn.cpu().numpy(), name='{}-th article'.format(i+ind),
                        X_label=extracted_batch[ind], Y_label=list(range(len(raw_art_sents))),
                        dirpath=save_path, pdf_page=pdf_pages)
                    pdf_pages.close()
                # plot_attn()

                ext = ext[:-1]
                act = act[:-1]

                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                    act = list([1]*5)[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    act = [i.item() for i in act]

                ext_nums.append(ext)

                ext_inds += [(len(ext_arts), len(ext))]  # e.g. [(0, 5), (5, 7), (12, 3), ...]
                ext_arts += [raw_art_sents[k] for k in ext]
                ext_acts += [k for k in act]

                # accumulate the extracted sentences seen so far
                ext_collections += [sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0]+k+1],[]) for k in range(ext_inds[-1][1])]

                abs_collections += [sum(abs_sents[:k+1],[]) if k<len(abs_sents) 
                                        else sum(abs_sents[0:len(abs_sents)],[]) 
                                        for k in range(ext_inds[-1][1])]

            if beam_size > 1:  # abstract with beam search
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)

                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]

                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each sent in extracted digest
                    # All abstract mapping
                    rouge_before_rewritten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewritten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewritten - rouge_after_rewritten
                    rewrite_less_rouge.append(diff_ins)
            
            else:  # greedy (single-beam) abstraction
                dec_outs = abstractor(ext_arts)
                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]
                # dec_outs_act = dec_outs
                # dec_outs_act = ext_arts
                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each sent in extracted digest
                    # All abstract mapping
                    rouge_before_rewritten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewritten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewritten - rouge_after_rewritten
                    rewrite_less_rouge.append(diff_ins)

            assert i == batch_size*i_debug

            for iters, (j, n) in enumerate(ext_inds):        
                
                do_right_rewrite = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge<0 and action==0])
                do_right_preserve = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge>=0 and action==1])
                
                decoded_words_nums = [len(dec) for dec in dec_outs_act[j:j+n]]
                ext_words_nums = [token_nums_batch[iters][x] for x in range(len(token_nums_batch[iters])) if x in ext_nums[iters]]

                # take all extracted labels
                # decoded_sents = [raw_article_batch[iters][x] for x in extracted_batch[iters]]
                # statistics [START]
                decoded_sents = [' '.join(dec) for dec in dec_outs_act[j:j+n]]
                rouge_score1 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=1)
                rouge_score2 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=2)
                rouge_scorel = compute_rouge_l(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]))
                
                dec_sent_nums = len(decoded_sents)
                abs_sent_nums = len(raw_abstract_batch[iters])
                doc_sent_nums = len(raw_article_batch[iters])
                
                doc_words_nums = sum(token_nums_batch[iters])
                ext_words_nums = sum(ext_words_nums)
                abs_words_nums = sum(decoded_words_nums)

                label_recall = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(extracted_batch[iters])
                label_precision = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(ext_nums[iters])
                less_rewrite = rewrite_less_rouge[j+n-1]
                dec_one_action_num = sum(ext_acts[j:j+n])
                dec_zero_action_num = n - dec_one_action_num

                ext_indices = '_'.join([str(i) for i in ext_nums[iters]])
                
                top3 = set([0,1,2]) <= set(ext_nums[iters])
                top3_gold = set([0,1,2]) <= set(extracted_batch[iters])
                
                # Any Top 2 
                top2 = set([0,1]) <= set(ext_nums[iters]) or set([1,2]) <= set(ext_nums[iters]) or set([0,2]) <= set(ext_nums[iters])
                top2_gold = set([0,1]) <= set(extracted_batch[iters]) or set([1,2]) <= set(extracted_batch[iters]) or set([0,2]) <= set(extracted_batch[iters])
                
                with open(join(save_path,'_statisticsDecode.log.csv'),'a') as w:
                    w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(i,rouge_score1,
                     rouge_score2, rouge_scorel, dec_sent_nums,
                      abs_sent_nums, doc_sent_nums, doc_words_nums,
                      ext_words_nums,abs_words_nums,(ext_words_nums - abs_words_nums),
                      label_recall, label_precision,
                      less_rewrite, dec_one_action_num, dec_zero_action_num, 
                      ext_indices, top3, top3_gold, top2, top2_gold,do_right_rewrite,do_right_preserve))
                # statistics END

                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    decoded_sents = [i for i in decoded_sents if i!='']
                    if len(decoded_sents) > 0:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    else:
                        f.write('')

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
            
    print()
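The rewrite/preserve statistics above compare ROUGE-1 before and after rewriting each accumulated extract. A simplified, self-contained stand-in for what `compute_rouge_n(..., n=1)` measures (unigram-overlap F1; an illustration only, not the project's implementation):

from collections import Counter

def rouge_1_f(candidate_tokens, reference_tokens):
    """Unigram-overlap F1 between two token lists (illustration only)."""
    if not candidate_tokens or not reference_tokens:
        return 0.0
    cand, ref = Counter(candidate_tokens), Counter(reference_tokens)
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)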
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda, min_len)

    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # check whether the extractor uses BERT
    try:
        _bert = extractor._net._bert
    except AttributeError:
        _bert = False
        print('no bert arg')

    if _bert:
        tokenizer = BertTokenizer.from_pretrained(
            'bert-large-uncased-whole-word-masking')
        print('bert tokenizer loaded')

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
                ext_arts = []
                ext_inds = []
                if _bert:
                    for raw_art_sents, raw_art in zip(tokenized_article_batch,
                                                      raw_article_batch):
                        tokenized_sents = [
                            tokenizer.tokenize(source_sent.lower())
                            for source_sent in raw_art
                        ]
                        tokenized_sents = [
                            tokenized_sent + ['[SEP]']
                            for tokenized_sent in tokenized_sents
                        ]
                        tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
                        word_num = [
                            len(tokenized_sent)
                            for tokenized_sent in tokenized_sents
                        ]
                        truncated_word_num = []
                        total_count = 0
                        for num in word_num:
                            if total_count + num < MAX_LEN_BERT:
                                truncated_word_num.append(num)
                            else:
                                truncated_word_num.append(MAX_LEN_BERT -
                                                          total_count)
                                break
                            total_count += num
                        tokenized_sents = list(
                            concat(tokenized_sents))[:MAX_LEN_BERT]
                        tokenized_sents = tokenizer.convert_tokens_to_ids(
                            tokenized_sents)
                        art_sents = tokenize(None, raw_art)
                        _input = (art_sents, tokenized_sents,
                                  truncated_word_num)

                        ext = extractor(_input)[:]  # SC extractor output has no EOE to drop
                        if not ext:
                            # use top-3 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(3))[:len(raw_art_sents)]
                        else:
                            ext = [i for i in ext]
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                else:
                    for raw_art_sents in tokenized_article_batch:
                        ext = extractor(raw_art_sents)[:]  # SC extractor output has no EOE to drop
                        if not ext:
                            # use top-5 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(5))[:len(raw_art_sents)]
                        else:
                            ext = [i for i in ext]
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
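The BERT branch above caps the word-piece sequence at MAX_LEN_BERT while recording how many pieces of each sentence survive. A tiny worked illustration with an assumed limit:

MAX_LEN_BERT = 10              # assumed limit, for illustration only
word_num = [4, 5, 6]           # word pieces per sentence
truncated_word_num, total_count = [], 0
for num in word_num:
    if total_count + num < MAX_LEN_BERT:
        truncated_word_num.append(num)
    else:
        truncated_word_num.append(MAX_LEN_BERT - total_count)
        break
    total_count += num
print(truncated_word_num)      # [4, 5, 1]: the third sentence is cut to one word piece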
Example 12
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None
                             else json.load(open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None
                            else json.load(open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    dec_log['beam'] = 1  # greedy decoding only
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += list(map(lambda i: raw_art_sents[i], ext))
            dec_outs = abstractor(ext_arts)
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path, 'output_{}/{}.dec'.format(k, i)),
                              'w') as f:
                        f.write(make_html_safe(dec_str))

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100, timedelta(seconds=int(time()-start))
                ), end='')
    print()
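This variant writes the k-th decoded sentence of article i to output_k/i.dec, so each sentence position gets its own directory. An illustrative layout, assuming MAX_ABS_NUM is 3:

# save_path/
#   log.json
#   output_0/0.dec  output_0/1.dec  ...   first decoded sentence of each article
#   output_1/0.dec  output_1/1.dec  ...   second decoded sentence
#   output_2/0.dec  output_2/1.dec  ...   third decoded sentence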
def decode_entity(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
    #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda, min_len=min_len)

    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda, entity=True)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        batch = list(filter(bool, batch))
        return batch

    if args.key == 1:
        key = 'filtered_rule1_input_mention_cluster'
    elif args.key == 2:
        key = 'filtered_rule23_6_input_mention_cluster'
    else:
        raise Exception
    dataset = DecodeDatasetEntity(split, key)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        sent_selected = 0
        with torch.no_grad():
            for i_debug, raw_input_batch in enumerate(loader):
                raw_article_batch, clusters = zip(*raw_input_batch)
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                #processed_clusters = map(preproc(list(tokenized_article_batch), clusters))
                #processed_clusters = list(zip(*processed_clusters))
                ext_arts = []
                ext_inds = []
                pre_abs = []
                beam_inds = []
                for raw_art_sents, raw_cls in zip(tokenized_article_batch, clusters):
                    processed_clusters = preproc(raw_art_sents, raw_cls)
                    ext = extractor((raw_art_sents, processed_clusters))[:]  # SC extractor output has no EOE to drop
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

                if beam_size > 1:
                    # all_beams = abstractor(ext_arts, beam_size, diverse)
                    # dec_outs = rerank_mp(all_beams, ext_inds)
                    all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                    dec_outs = rerank_mp(all_beams, beam_inds)
                else:
                    dec_outs = abstractor(pre_abs)
                for dec_out in dec_outs:
                    dec_out = sent_tokenize(' '.join(dec_out))
                    ext = [sent.split(' ') for sent in dec_out]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += ext

                dec_outs = ext_arts
                assert i == batch_size*i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i/n_data*100,
                        timedelta(seconds=int(time()-start))
                    ), end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size*i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i/n_data*100,
                        timedelta(seconds=int(time()-start))
                    ), end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
    if sc:
        print('average sentences selected:', sent_selected / i)
Example 14
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = lambda x,y:x
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            print('BEAM')
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    total_leng = 0
    total_num = 0
    with torch.no_grad():
        for i_debug, data_batch in enumerate(loader):
            raw_article_batch, sent_label_batch = tuple(map(list, unzip(data_batch)))
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            #ext_arts = []
            ext_inds = []
            dirty = []
            ext_sents = []
            masks = []
            for raw_art_sents, sent_labels in zip(tokenized_article_batch, sent_label_batch):
                ext = extractor(raw_art_sents, sent_labels)  # exclude EOE

                tmp_size = min(max_dec_edu, len(ext) - 1)
                #total_leng += sum([len(e) -1 for e in ext[:-1]])
                #total_num += len(ext) - 1
                #print(tmp_size, len(ext) - 1)
                ext_inds += [(len(ext_sents), tmp_size)]
                tmp_stop = ext[-1][-1].item()
                tmp_truncate = tmp_stop - 1
                str_arts = list(map(lambda x: ' '.join(x), raw_art_sents))
                for idx in ext[:tmp_size]:
                    t, m = rl_edu_to_sentence(str_arts, idx)
                    total_leng += len(t)
                    total_num += 1
                    assert len(t) == len(m)
                    if t == []:
                        assert len(idx) == 1
                        id = idx[0].item()
                        if id == tmp_truncate:
                            dirty.append(len(ext_sents))
                            ext_sents.append(label)
                            masks.append(label_mask)
                    else:
                        if idx[-1].item() != tmp_stop:
                            ext_sents.append(t)
                            masks.append(m)


                #ext_arts += [raw_art_sents[i] for i in ext]
            #print(ext_sents)
            #print(masks)
            #print(dirty)
            #exit(0)
            if beam_size > 1:
                #print(ext_sents)
                #print(masks)
                all_beams = abstractor(ext_sents, masks, beam_size, diverse)
                print('rerank')
                dec_outs = rerank_mp(all_beams, ext_inds)
                for d in dirty:
                    dec_outs[d] = []
                # TODO:!!!!!!!!!!!
            else:
                dec_outs = abstractor(ext_sents, masks)
                for d in dirty:
                    dec_outs[d] = []
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                if i % 100 == 0:
                    print(total_leng / total_num)
                i += 1

                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
    print()
Example 15
model_dir_ja = Config.model_dir_ja
model_dir_en = Config.model_dir_en
beam_size = Config.beam_size
max_len = Config.max_len
cuda = Config.cuda

with open(join(model_dir_ja, 'meta.json')) as f:
    meta_ja = json.loads(f.read())
if meta_ja['net_args']['abstractor'] is None:
    # NOTE: if no abstractor is provided then
    #       the whole model would be extractive summarization
    assert beam_size == 1
    abstractor_ja = identity
else:
    if beam_size == 1:
        abstractor_ja = Abstractor(join(model_dir_ja, 'abstractor'), max_len,
                                   cuda)
    else:
        abstractor_ja = BeamAbstractor(join(model_dir_ja, 'abstractor'),
                                       max_len, cuda)
extractor_ja = RLExtractor(model_dir_ja, cuda=cuda)

with open(join(model_dir_en, 'meta.json')) as f:
    meta_en = json.loads(f.read())
if meta_en['net_args']['abstractor'] is None:
    # NOTE: if no abstractor is provided then
    #       the whole model would be extractive summarization
    assert beam_size == 1
    abstractor_en = identity
else:
    if beam_size == 1:
        abstractor_en = Abstractor(join(model_dir_en, 'abstractor'), max_len,
                                   cuda)
    else:
        abstractor_en = BeamAbstractor(join(model_dir_en, 'abstractor'),
                                       max_len, cuda)
extractor_en = RLExtractor(model_dir_en, cuda=cuda)

def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    """
    print('Docoding extraction result....')
    abstractor = identity
    """
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            print('Decoding full model result with 1 beamsize')
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            print('Decoding full model result with {} beamsize'.format(
                beam_size))
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)

    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    filename = 'extracted_index.txt'
    with torch.no_grad():
        #f = open(os.path.join(save_path,filename), 'w')
        #f.close()
        for i_debug, batch_data in enumerate(loader):
            #tokenized_article_batch = map(tokenize(None), raw_article_batch)
            batch_data = map(tokenize_decode(None), batch_data)
            ext_arts = []
            ext_inds = []
            for data in batch_data:
                raw_art_sents, topic = data
                ext = extractor(raw_art_sents, topic)[:-1]  # exclude EOE
                if not ext:
                    ## use top-1 if nothing is extracted
                    ## in some rare cases rnn-ext does not extract at all
                    ext = list(range(1))[:len(
                        raw_art_sents)]  # YUNZHU change from 5 to 1
                    ## if want the extractor result   !!!  ######################
                    #with open(os.path.join(save_path, filename), 'a') as f:
                    #    line = [str(i) for i in ext]+['\n']
                    #    f.writelines(line)
                    #print(i)
                    ############################################################
                else:
                    ext = [i.item() for i in ext]
                    ## if want the extractor result   !!!  #####################
                    #with open('save_decode_extract/'+filename, 'a') as f:
                    #    line = [str(i) for i in ext]+['\n']
                    #    f.writelines(line)
                    #print(i)

                #############################################################
                #pdb.set_trace()
                #i+=1

                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]

            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug

            for j, n in ext_inds:

                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')

    print()
Example 17
def decode(save_path, save_file, model_dir, split, batch_size, beam_size,
           diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()

    # probably not needed for the CNN/DailyMail dataset
    f = open(save_file, "w")
    summaries_files = os.listdir(join(save_path, 'output'))
    n = len(summaries_files)
    summaries_list = [""] * n

    for fname in summaries_files:
        num = int(fname.replace(".dec", ""))
        f_local = open(join(save_path, "output", fname))
        summaries_list[num] = f_local.read().replace("\n", " ")
        f_local.close()

    assert (len(summaries_list) == n)

    f.write("\n".join(summaries_list))
    f.close()
Example 18
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda, trans=False):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    # os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None
                             else json.load(open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None
                            else json.load(open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    dec_log['beam'] = 1  # greedy decoding only
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            if trans:
                tokenized_article_batch = raw_article_batch
            else:
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                if trans:
                    ext, batch = extractor(raw_art_sents)
                    art_sents = batch.src_str[0]
                    # print(ext, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels])
                    for k, idx in enumerate([ext]):
                        _pred = []
                        _ids = []
                        if (len(batch.src_str[k]) == 0):
                            continue
                        for j in idx[:min(len(ext), len(batch.src_str[k]))]:
                            if (j >= len(batch.src_str[k])):
                                continue
                            candidate = batch.src_str[k][j].strip()
                            if (not _block_tri(candidate, _pred)):
                                _pred.append(candidate)
                                _ids.append(j)
                            else:
                                continue

                            if (len(_pred) == 3):
                                break
                    # print(ext, _ids, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels], list(map(lambda i: art_sents[i], ext)))
                    ext = _ids
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += list(map(lambda i: art_sents[i], ext))
                else:
                    ext = extractor(raw_art_sents)
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += list(map(lambda i: raw_art_sents[i], ext))
            dec_outs = abstractor(ext_arts)
            # print(dec_outs)
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                if trans:
                    decoded_sents = dec_outs[j:j+n]
                else:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path, 'output_{}/{}.dec'.format(k, i)),
                          'w') as f:
                        f.write(make_html_safe(dec_str)) #f.write(make_html_safe('\n'.join(decoded_sents)))

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100, timedelta(seconds=int(time()-start))
                ), end='')
            # if i_debug == 1:
                # break
    print()
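
The decode variant above (and a later example via the tri_block flag) filters candidate sentences with a _block_tri helper that is not shown. A minimal sketch, assuming the standard trigram-blocking heuristic (skip a candidate if it shares any trigram with an already-selected sentence); the exact implementation in the source repository may differ:

def _get_ngrams(n, tokens):
    """Set of n-grams (as tuples) in a token list."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def _block_tri(candidate, selected):
    """True if `candidate` shares a trigram with any already-selected sentence string."""
    tri_c = _get_ngrams(3, candidate.split())
    return any(tri_c & _get_ngrams(3, sent.split()) for sent in selected)
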
Esempio n. 19
0
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()

    if beam_size == 1:
        abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
    else:
        abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        articles = [" ".join(article) for article in articles]
        return articles

    dataset = DecodeDataset(args.data_path, split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = tokenize(1000, raw_article_batch)
            batch_size = len(tokenized_article_batch)

            ext_inds = []
            for num in range(batch_size):
                ext_inds += [(num, 1)]
            if beam_size > 1:
                all_beams = abstractor(tokenized_article_batch, beam_size,
                                       diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(tokenized_article_batch)

            assert i == batch_size * i_debug

            for index in range(batch_size):

                decoded_sents = [
                    ' '.join(dec.split(",")) for dec in dec_outs[index]
                ]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe(' '.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
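
Several of these decode variants rerank beam-search hypotheses with rerank_mp, which is not shown here. A simplified single-process sketch of the presumed idea (the real helper likely parallelizes over groups with multiprocessing and may score hypotheses differently): for each extracted-sentence group, jointly choose the hypotheses that minimize repeated bigrams and maximize length-normalized log-probability. Hypothesis objects are assumed to expose .sequence (a token list) and .logprob.

from collections import Counter
from itertools import product


def rerank_sketch(all_beams, ext_inds, prune=5):
    """Pick one hypothesis per beam so that, per group, bigram repetition is
    minimized and length-normalized log-probability is maximized."""
    def bigrams(seq):
        return Counter(zip(seq, seq[1:]))

    def score(hyps):
        counts = sum((bigrams(h.sequence) for h in hyps), Counter())
        repeats = sum(c - 1 for c in counts.values() if c > 1)
        avg_lp = sum(h.logprob for h in hyps) / sum(len(h.sequence) for h in hyps)
        return (-repeats, avg_lp)

    dec_outs = []
    for start, n in ext_inds:
        groups = [beams[:prune] for beams in all_beams[start:start + n]]
        best = max(product(*groups), key=score)
        dec_outs.extend(h.sequence for h in best)
    return dec_outs
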
Esempio n. 20
0
def decode(args):
    save_path = args.path
    model_dir = args.model_dir
    batch_size = args.batch
    beam_size = args.beam
    diverse = args.div
    max_len = args.max_dec_word
    cuda = args.cuda
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(args)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    if not os.path.exists(join(save_path, 'output')):
        os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = args.mode
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                decoded_sents = decoded_sents[:20]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
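
The greedy branches above tokenize articles with map(tokenize(None), raw_article_batch), i.e. tokenize is curried over a maximum token count (None keeps everything). A hypothetical stand-in written as a plain closure; the original probably uses cytoolz.curry:

def tokenize(max_len):
    """Return a function that splits each sentence string of an article into
    lower-cased tokens, truncated to max_len tokens (no truncation if None)."""
    def _tokenize(texts):
        return [t.strip().lower().split()[:max_len] for t in texts]
    return _tokenize
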
Esempio n. 21
0
def decode(save_path,
           model_dir,
           split,
           batch_size,
           beam_size,
           diverse,
           max_len,
           cuda,
           bart=False,
           clip=-1,
           tri_block=False):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            # raw_article_batch
            tokenized_article_batch = map(tokenize(None),
                                          [r[0] for r in raw_article_batch])
            tokenized_abs_batch = map(tokenize(None),
                                      [r[1] for r in raw_article_batch])
            ext_arts = []
            ext_inds = []
            for raw_art_sents, raw_abs_sents in zip(tokenized_article_batch,
                                                    tokenized_abs_batch):
                ext, raw_art_sents = extractor(raw_art_sents,
                                               raw_abs_sents=raw_abs_sents)
                # print(raw_art_sents)
                ext = ext[:-1]  # exclude EOE
                # print(ext)
                if tri_block:
                    _pred = []
                    _ids = []
                    for j in ext:
                        if (j >= len(raw_art_sents)):
                            continue
                        candidate = " ".join(raw_art_sents[j]).strip()
                        if (not _block_tri(candidate, _pred)):
                            _pred.append(candidate)
                            _ids.append(j)
                        else:
                            continue

                        if (len(_pred) == 3):
                            break
                    ext = _ids
                    # print(_pred)
                if clip > 0 and len(ext) > clip:
                    # cap the number of extracted sentences (added for clipping experiments; to be reverted)
                    ext = ext[:clip]
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    # print(ext)
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if bart:
                # rewrite the extracted sentences with an external BART summarizer
                dec_outs = get_bart_summaries(ext_arts,
                                              tokenizer,
                                              bart_model,
                                              beam_size=beam_size)
            else:
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
            # print(dec_outs, i, i_debug)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
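
The bart branch of the decode variant above rewrites the extracted sentences with get_bart_summaries, using module-level tokenizer and bart_model objects that are not shown. A hypothetical sketch with Hugging Face transformers (checkpoint name and helper signature are assumptions), returning token lists so the output format matches the abstractor's:

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')  # assumed checkpoint
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')


def get_bart_summaries_sketch(ext_arts, tokenizer, model, beam_size=4, max_length=140):
    """Summarize each extracted sentence (a token list) with BART and return token lists."""
    texts = [' '.join(sent) for sent in ext_arts]
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model.generate(inputs['input_ids'],
                                 attention_mask=inputs['attention_mask'],
                                 num_beams=beam_size,
                                 max_length=max_length,
                                 early_stopping=True)
    return [tokenizer.decode(o, skip_special_tokens=True).split() for o in outputs]
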
def decode_graph(save_path, model_dir, split, batch_size, beam_size, diverse,
                 max_len, cuda, sc, min_len, docgraph, paragraph):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len,
                                        cuda,
                                        min_len=min_len)

    print('docgraph:', docgraph)
    extractor = SCExtractor(model_dir,
                            cuda=cuda,
                            docgraph=docgraph,
                            paragraph=paragraph)
    adj_type = extractor._net._adj_type
    bert = extractor._net._bert
    if bert:
        tokenizer = extractor._net._bert
        try:
            with open(
                    '/data/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align = pickle.load(f)
        except FileNotFoundError:
            with open(
                    '/data2/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align = pickle.load(f)

        try:
            with open(
                    '/data/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align2 = pickle.load(f)
        except FileNotFoundError:
            with open(
                    '/data2/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align2 = pickle.load(f)

        align.update(align2)

    # setup loader
    def coll(batch):
        batch = list(filter(bool, batch))
        return batch

    dataset = DecodeDatasetGAT(split, args.key)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding

    i = 0
    length = 0
    sent_selected = 0
    with torch.no_grad():
        for i_debug, raw_input_batch in enumerate(loader):
            raw_article_batch, nodes, edges, paras, subgraphs = zip(
                *raw_input_batch)
            if bert:
                art_sents = [[
                    tokenizer.tokenize(source_sent)
                    for source_sent in source_sents
                ] for source_sents in raw_article_batch]
                for _i in range(len(art_sents)):
                    art_sents[_i][0] = [tokenizer.bos_token] + art_sents[_i][0]
                    art_sents[_i][-1] = art_sents[_i][-1] + [
                        tokenizer.eos_token
                    ]
                truncated_word_nums = []
                word_nums = [[len(sent) for sent in art_sent]
                             for art_sent in art_sents]
                for word_num in word_nums:
                    truncated_word_num = []
                    total_count = 0
                    for num in word_num:
                        if total_count + num < args.max_dec_word:
                            truncated_word_num.append(num)
                        else:
                            truncated_word_num.append(args.max_dec_word -
                                                      total_count)
                            break
                        total_count += num
                    truncated_word_nums.append(truncated_word_num)
                sources = [
                    list(concat(art_sent))[:args.max_dec_word]
                    for art_sent in art_sents
                ]
            else:
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
            #processed_clusters = map(preproc(list(tokenized_article_batch), clusters))
            #processed_clusters = list(zip(*processed_clusters))
            ext_arts = []
            ext_inds = []
            pre_abs = []
            beam_inds = []
            if bert:
                for raw_art_sents, source, art_sent, word_num, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip(
                        raw_article_batch, sources, art_sents,
                        truncated_word_nums, nodes, edges, paras, subgraphs):
                    processed_nodes = prepro_rl_graph_bert(
                        align, raw_art_sents, source, art_sent,
                        args.max_dec_word, raw_nodes, raw_edges, raw_paras,
                        raw_subgraphs, adj_type, docgraph)
                    _input = (raw_art_sents,
                              source) + processed_nodes + (word_num, )
                    ext = extractor(_input)[:]
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

            else:
                for raw_art_sents, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip(
                        tokenized_article_batch, nodes, edges, paras,
                        subgraphs):
                    processed_nodes = prepro_rl_graph(raw_art_sents, raw_nodes,
                                                      raw_edges, raw_paras,
                                                      raw_subgraphs, adj_type,
                                                      docgraph)
                    _input = (raw_art_sents, ) + processed_nodes

                    ext = extractor(_input)[:]  # full copy of the extracted indices
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

            if beam_size > 1:
                # all_beams = abstractor(ext_arts, beam_size, diverse)
                # dec_outs = rerank_mp(all_beams, ext_inds)
                all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                dec_outs = rerank_mp(all_beams, beam_inds)
            else:
                dec_outs = abstractor(pre_abs)
            for dec_out in dec_outs:
                dec_out = sent_tokenize(' '.join(dec_out))
                ext = [sent.split(' ') for sent in dec_out]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += ext
            dec_outs = ext_arts
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
                length += len(decoded_sents)
    print('average summary length:', length / i)
    print('average sentences selected:', sent_selected / i)
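
Unlike the per-sentence rewriting in the other examples, decode_graph feeds all extracted sentences of an article to the abstractor as one flat token sequence (pre_abs) and only afterwards recovers sentence boundaries with nltk's sent_tokenize before writing the .dec file. A small illustration of that re-splitting step (the sample tokens are hypothetical):

from nltk import sent_tokenize

dec_out = ['the', 'market', 'rallied', '.', 'analysts', 'were', 'surprised', '.']  # hypothetical abstractor output
sents = sent_tokenize(' '.join(dec_out))    # flat summary string -> sentence strings
ext = [sent.split(' ') for sent in sents]   # back to per-sentence token lists for writing
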
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    # os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    count = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                for ind_file, (beg, n_ext) in enumerate(ext_inds):
                    article_beams = all_beams[beg:beg + n_ext]
                    file = {}
                    for ind_sent, sent in enumerate(article_beams):
                        file[ind_sent] = defaultdict(list)
                        sentence = " ".join(ext_arts[beg + ind_sent])
                        file[ind_sent]['sentence'].append(sentence)
                        for hypothesis in sent:
                            file[ind_sent]['summarizer_logprob'].append(
                                hypothesis.logprob)
                            file[ind_sent]['hypotheses'].append(" ".join(
                                hypothesis.sequence))

                    with open(
                            os.path.join('exported_beams',
                                         '{}.json'.format(count + ind_file)),
                            'w') as f:
                        json.dump(file, f, ensure_ascii=False)
                count += batch_size