Example #1
def dump(data_dir, split):
    start = time()
    print('start processing {} split...'.format(split))
    split_dir = os.path.join(data_dir, split)
    # refs/ lives next to the split dir, not inside it
    dump_dir = os.path.join(data_dir, 'refs', split)
    n_data = count_data(split_dir)
    for i in range(n_data):
        print('processing {}/{} ({:.2f}%)\r'.format(i, n_data, 100*i/n_data), end='')
        with open(os.path.join(split_dir, '{}.json'.format(i))) as f:
            data = json.loads(f.read())
        abs_sents = data['abstract']
        with open(os.path.join(dump_dir, '{}.ref'.format(i)), 'w') as f:
            f.write(make_html_safe('\n'.join(abs_sents)))
    print('finished in {}'.format(timedelta(seconds=time()-start)))
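A minimal usage sketch, assuming the `count_data` and `make_html_safe` helpers from the surrounding repository and the directory layout implied by the joins above (the path is hypothetical):

# hypothetical call: reads <root>/val/<i>.json and writes <root>/refs/val/<i>.ref
dump('/path/to/finished_files', 'val')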
def decode_all(save_path, model_dir, split, batch_size, beam_size, diverse,
               max_len, cuda):
    start = time()
    # setup model
    model = Model(model_dir, beam_size, diverse, max_len, cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = model.meta['net_args']['abstractor']
    dec_log['extractor'] = model.meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = model.beam_size
    dec_log['diverse'] = model.diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            dec_outs, ext_inds = model.decode(raw_article_batch)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
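Throughout these examples, the `(j, n)` pairs in `ext_inds` map the flat list of decoded sentences back to documents: each document owns the slice `dec_outs[j:j+n]`. A toy illustration of that bookkeeping:

# toy data standing in for real decoder output
dec_outs = [['a'], ['b'], ['c'], ['d'], ['e']]
ext_inds = [(0, 2), (2, 3)]  # doc 0 -> sentences 0..1, doc 1 -> sentences 2..4
for j, n in ext_inds:
    print([' '.join(sent) for sent in dec_outs[j:j + n]])
# ['a', 'b']
# ['c', 'd', 'e']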
def dump(split):
    start = time()
    print('start processing {} split...'.format(split))
    data_dir = join(DATA_DIR, split)
    dump_dir = join(DATA_DIR, 'refs', split)
    n_data = count_data(data_dir)
    for i in range(n_data):
        print('processing {}/{} ({:.2f}%)\r'.format(i, n_data, 100*i/n_data),
              end='')
        with open(join(data_dir, '{}.json'.format(i))) as f:
            data = json.loads(f.read())
        abs_sents = data['abstract']
        with open(join(dump_dir, '{}.ref'.format(i)), 'w') as f:
            f.write(make_html_safe('\n'.join(abs_sents)))
    print('finished in {}'.format(timedelta(seconds=time()-start)))
def dump(split):
    start = time()
    print('start processing {} split...'.format(split))
    data_dir = join(DATA_DIR, split)
    dump_dir = join(DATA_DIR, 'refs', split)
    n_data = count_data(data_dir)

    for i, file in enumerate(iter_files(data_dir)):
        print('processing {}/{} ({:.2f}%)\r'.format(i, n_data,
                                                     100 * i / n_data),
              end='')
        name = os.path.basename(file)
        name, _ = os.path.splitext(name)

        with open(join(data_dir, '{}.json'.format(name))) as f:
            data = json.loads(f.read())
        abs_sents = data['abstract']
        with open(join(dump_dir, '{}.ref'.format(name)), 'w') as f:
            f.write(make_html_safe('\n'.join(abs_sents)))
    print('finished in {}'.format(timedelta(seconds=time() - start)))
def decode_graph(save_path, model_dir, split, batch_size, beam_size, diverse,
                 max_len, cuda, sc, min_len, docgraph, paragraph):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len,
                                        cuda,
                                        min_len=min_len)

    print('docgraph:', docgraph)
    extractor = SCExtractor(model_dir,
                            cuda=cuda,
                            docgraph=docgraph,
                            paragraph=paragraph)
    adj_type = extractor._net._adj_type
    bert = extractor._net._bert
    if bert:
        tokenizer = extractor._net._bert
        try:
            with open(
                    '/data/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align = pickle.load(f)
        except FileNotFoundError:
            with open(
                    '/data2/luyang/process-nyt/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align = pickle.load(f)

        try:
            with open(
                    '/data/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align2 = pickle.load(f)
        except FileNotFoundError:
            with open(
                    '/data2/luyang/process-cnn-dailymail/bert_tokenizaiton_aligns/robertaalign-base-cased.pkl',
                    'rb') as f:
                align2 = pickle.load(f)

        align.update(align2)

    # setup loader
    def coll(batch):
        batch = list(filter(bool, batch))
        return batch

    dataset = DecodeDatasetGAT(split, args.key)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding

    i = 0
    length = 0
    sent_selected = 0
    with torch.no_grad():
        for i_debug, raw_input_batch in enumerate(loader):
            raw_article_batch, nodes, edges, paras, subgraphs = zip(
                *raw_input_batch)
            if bert:
                art_sents = [[
                    tokenizer.tokenize(source_sent)
                    for source_sent in source_sents
                ] for source_sents in raw_article_batch]
                for _i in range(len(art_sents)):
                    art_sents[_i][0] = [tokenizer.bos_token] + art_sents[_i][0]
                    art_sents[_i][-1] = art_sents[_i][-1] + [
                        tokenizer.eos_token
                    ]
                truncated_word_nums = []
                word_nums = [[len(sent) for sent in art_sent]
                             for art_sent in art_sents]
                for word_num in word_nums:
                    truncated_word_num = []
                    total_count = 0
                    for num in word_num:
                        if total_count + num < args.max_dec_word:
                            truncated_word_num.append(num)
                        else:
                            truncated_word_num.append(args.max_dec_word -
                                                      total_count)
                            break
                        total_count += num
                    truncated_word_nums.append(truncated_word_num)
                sources = [
                    list(concat(art_sent))[:args.max_dec_word]
                    for art_sent in art_sents
                ]
            else:
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
            #processed_clusters = map(preproc(list(tokenized_article_batch), clusters))
            #processed_clusters = list(zip(*processed_clusters))
            ext_arts = []
            ext_inds = []
            pre_abs = []
            beam_inds = []
            if bert:
                for raw_art_sents, source, art_sent, word_num, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip(
                        raw_article_batch, sources, art_sents,
                        truncated_word_nums, nodes, edges, paras, subgraphs):
                    processed_nodes = prepro_rl_graph_bert(
                        align, raw_art_sents, source, art_sent,
                        args.max_dec_word, raw_nodes, raw_edges, raw_paras,
                        raw_subgraphs, adj_type, docgraph)
                    _input = (raw_art_sents,
                              source) + processed_nodes + (word_num, )
                    ext = extractor(_input)[:]
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

            else:
                for raw_art_sents, raw_nodes, raw_edges, raw_paras, raw_subgraphs in zip(
                        tokenized_article_batch, nodes, edges, paras,
                        subgraphs):
                    processed_nodes = prepro_rl_graph(raw_art_sents, raw_nodes,
                                                      raw_edges, raw_paras,
                                                      raw_subgraphs, adj_type,
                                                      docgraph)
                    _input = (raw_art_sents, ) + processed_nodes

                    ext = extractor(_input)[:]  # exclude EOE
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

            if beam_size > 1:
                # all_beams = abstractor(ext_arts, beam_size, diverse)
                # dec_outs = rerank_mp(all_beams, ext_inds)
                all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                dec_outs = rerank_mp(all_beams, beam_inds)
            else:
                dec_outs = abstractor(pre_abs)
            for dec_out in dec_outs:
                dec_out = sent_tokenize(' '.join(dec_out))
                ext = [sent.split(' ') for sent in dec_out]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += ext
            dec_outs = ext_arts
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
                length += len(decoded_sents)
    print('average summary length:', length / i)
    print('average sentences selected:', sent_selected / i)
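The sentence-budget truncation inside the `bert` branch above caps the cumulative token count at `args.max_dec_word`: it keeps whole sentence lengths until the budget runs out, then keeps only the remainder and stops. The same logic as a standalone sketch (`max_words` stands in for `args.max_dec_word`):

def truncate_word_nums(word_num, max_words):
    # keep whole sentence lengths until the budget is hit,
    # then keep only the remainder of the budget and stop
    truncated, total = [], 0
    for num in word_num:
        if total + num < max_words:
            truncated.append(num)
            total += num
        else:
            truncated.append(max_words - total)
            break
    return truncated

print(truncate_word_nums([10, 20, 30], 35))  # [10, 20, 5]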
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None
                             else json.load(open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None
                            else json.load(open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    dec_log['beam'] = 1  # greedy decoding only
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += list(map(lambda i: raw_art_sents[i], ext))
            dec_outs = abstractor(ext_arts)
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path, 'output_{}/{}.dec'.format(k, i)),
                              'w') as f:
                        f.write(make_html_safe(dec_str))

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100, timedelta(seconds=int(time()-start))
                ), end='')
    print()
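The `ext_dir is None` fallback above is a lead-N extractor: it returns the indices of the first `MAX_ABS_NUM` sentences (`MAX_ABS_NUM` is a repository-level constant this listing never defines). Written out as a plain function:

def lead_n_extractor(art_sents, max_n):
    # equivalent of the lambda fallback above
    return list(range(len(art_sents)))[:max_n]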
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda, trans=False):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    # os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None
                             else json.load(open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None
                            else json.load(open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    dec_log['beam'] = 1  # greedy decoding only
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            if trans:
                tokenized_article_batch = raw_article_batch
            else:
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                if trans:
                    ext, batch = extractor(raw_art_sents)
                    art_sents = batch.src_str[0]
                    # print(ext, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels])
                    for k, idx in enumerate([ext]):
                        _pred = []
                        _ids = []
                        if (len(batch.src_str[k]) == 0):
                            continue
                        for j in idx[:min(len(ext), len(batch.src_str[k]))]:
                            if (j >= len(batch.src_str[k])):
                                continue
                            candidate = batch.src_str[k][j].strip()
                            if (not _block_tri(candidate, _pred)):
                                _pred.append(candidate)
                                _ids.append(j)
                            else:
                                continue

                            if (len(_pred) == 3):
                                break
                    # print(ext, _ids, [x.nonzero(as_tuple=True)[0] for x in batch.src_sent_labels], list(map(lambda i: art_sents[i], ext)))
                    ext = _ids
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += list(map(lambda i: art_sents[i], ext))
                else:
                    ext = extractor(raw_art_sents)
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += list(map(lambda i: raw_art_sents[i], ext))
            dec_outs = abstractor(ext_arts)
            # print(dec_outs)
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                if trans:
                    decoded_sents = dec_outs[j:j+n]
                else:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path, 'output_{}/{}.dec'.format(k, i)),
                          'w') as f:
                        f.write(make_html_safe(dec_str)) #f.write(make_html_safe('\n'.join(decoded_sents)))

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100, timedelta(seconds=int(time()-start))
                ), end='')
            # if i_debug == 1:
                # break
    print()
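`_block_tri`, used in the `trans` branch above, is not defined anywhere in this listing; in PreSumm-style extractive decoding it rejects a candidate sentence that shares a trigram with any already-selected sentence. A sketch consistent with that usage:

def _get_ngrams(n, words):
    # set of n-grams over a token list
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}

def _block_tri(candidate, selected):
    # True if the candidate shares any trigram with a selected sentence
    tri_c = _get_ngrams(3, candidate.split())
    return any(tri_c & _get_ngrams(3, s.split()) for s in selected)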
Example #9
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda,
           min_len):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        abstractor = identity
    else:
        #abstractor = Abstractor(abs_dir, max_len, cuda)
        abstractor = BeamAbstractor(abs_dir,
                                    max_len,
                                    cuda,
                                    min_len,
                                    reverse=args.reverse)
    if ext_dir is None:
        # NOTE: if no extractor is provided then
        #       it would be the lead-N extractor
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        if args.no_force_ext:
            extractor = Extractor(ext_dir,
                                  max_ext=MAX_ABS_NUM,
                                  cuda=cuda,
                                  force_ext=not args.no_force_ext)
        else:
            extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    os.makedirs(save_path)
    # prepare save paths and logs
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    dec_log['extractor'] = (None if ext_dir is None else json.load(
        open(join(ext_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    if abs_dir is not None:
        dec_log['beam'] = 5  # beam search with size 5
        beam_size = 5
    else:
        dec_log['beam'] = 1
        beam_size = 1
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    print(dec_log['extractor'])
    if not dec_log['extractor']['net_args']['stop'] and not args.no_force_ext:
        for i in range(MAX_ABS_NUM + 1):
            os.makedirs(join(save_path, 'output_{}'.format(i)))
    else:
        os.makedirs(join(save_path, 'output'))

    # Decoding
    i = 0
    length = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            pre_abs = []
            beam_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_art = list(map(lambda i: raw_art_sents[i], ext))
                pre_abs.append([word for sent in ext_art for word in sent])
                beam_inds += [(len(beam_inds), 1)]

            if beam_size > 1:
                all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                dec_outs = rerank_mp(all_beams, beam_inds)
            else:
                dec_outs = abstractor(pre_abs)

            for dec_out in dec_outs:
                dec_out = sent_tokenize(' '.join(dec_out))
                ext = [sent.split(' ') for sent in dec_out]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += ext
            if (not dec_log['extractor']['net_args']['stop']
                    and not args.no_force_ext):
                dec_outs = ext_arts
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    for k, dec_str in enumerate(decoded_sents):
                        if k > MAX_ABS_NUM - 2:
                            break
                        with open(
                                join(save_path,
                                     'output_{}/{}.dec'.format(k, i)),
                                'w') as f:
                            f.write(make_html_safe(dec_str))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
            else:
                dec_outs = ext_arts
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
        print('average summary length:', length / i)
Example #10
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()

    if beam_size == 1:
        abstractor = Abstractor(join(model_dir, 'abstractor'), max_len, cuda)
    else:
        abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        articles = [" ".join(article) for article in articles]
        return articles

    dataset = DecodeDataset(args.data_path, split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = tokenize(1000, raw_article_batch)
            # the final batch may be smaller; keep the loader batch_size intact
            cur_batch_size = len(tokenized_article_batch)

            ext_inds = []
            for num in range(cur_batch_size):
                ext_inds += [(num, 1)]
            if beam_size > 1:
                all_beams = abstractor(tokenized_article_batch, beam_size,
                                       diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(tokenized_article_batch)

            assert i == batch_size * i_debug

            for index in range(cur_batch_size):

                decoded_sents = [
                    ' '.join(dec.split(",")) for dec in dec_outs[index]
                ]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe(' '.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
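`rerank_mp` is imported from the surrounding repository; in the fast_abs_rl lineage it reranks beam hypotheses with a length-normalized score and an n-gram repetition penalty, parallelized across worker processes. A deliberately simplified single-process sketch of the idea, picking each slot's best hypothesis independently and assuming each hypothesis exposes `sequence` (tokens) and `logprob`:

from collections import Counter

def rerank_simple(all_beams):
    # score = length-normalized log-prob minus a repeated-trigram penalty
    def score(hyp):
        tri = Counter(zip(hyp.sequence, hyp.sequence[1:], hyp.sequence[2:]))
        repeat = sum(c - 1 for c in tri.values())
        return hyp.logprob / max(len(hyp.sequence), 1) - repeat
    return [max(beam, key=score).sequence for beam in all_beams]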
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles, abstract, extracted = unzip(batch)
        articles = list(filter(bool, articles))
        abstract = list(filter(bool, abstract))
        extracted =  list(filter(bool, extracted))
        return articles, abstract, extracted

    dataset = DecodeDataset(split)
    n_data = len(dataset[0])  # number of articles
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )
    # prepare save paths and logs
    if os.path.exists(join(save_path, 'output')):
        pass
    else:
        os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse

    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    
    file_path = os.path.join(save_path, 'Attention')
    act_path = os.path.join(save_path, 'Actions')

    header = "index,rouge_score1,rouge_score2,"+\
    "rouge_scorel,dec_sent_nums,abs_sent_nums,doc_sent_nums,doc_words_nums,"+\
    "ext_words_nums,abs_words_nums,diff,"+\
    "recall,precision,less_rewrite,preserve_action,rewrite_action,each_actions,"+\
    "top3AsAns,top3AsGold,any_top2AsAns,any_top2AsGold,true_rewrite,true_preserve\n"


    if not os.path.exists(file_path):
        print('create dir:{}'.format(file_path))
        os.makedirs(file_path)

    if not os.path.exists(act_path):
        print('create dir:{}'.format(act_path))
        os.makedirs(act_path)

    with open(join(save_path,'_statisticsDecode.log.csv'),'w') as w:
        w.write(header)  
        
    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, (raw_article_batch, raw_abstract_batch, extracted_batch) in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            tokenized_abstract_batch = map(tokenize(None), raw_abstract_batch)
            token_nums_batch = list(map(token_nums(None), raw_article_batch))

            ext_nums = []
            ext_arts = []
            ext_inds = []
            rewrite_less_rouge = []
            dec_outs_act = []
            ext_acts = []
            abs_collections = []
            ext_collections = []

            # extract sentences
            for ind, (raw_art_sents, abs_sents) in enumerate(zip(tokenized_article_batch ,tokenized_abstract_batch)):

                (ext, (state, act_dists)), act = extractor(raw_art_sents)  # exclude EOE
                extracted_state = state[extracted_batch[ind]]
                attn = torch.softmax(state.mm(extracted_state.transpose(1,0)),dim=-1)
                # (_, abs_state), _ = extractor(abs_sents)  # exclude EOE
                
                def plot_actDist(actions, nums):
                    print('index: {} distribution ...'.format(nums))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(act_path, '{}.attention.pdf'.format(nums))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(actions.cpu().numpy(), name='{}-th article'.format(nums),
                        X_label=list(range(len(raw_art_sents))), Y_label=list(range(len(ext))),
                        dirpath=save_path, pdf_page=pdf_pages, action=True)
                    pdf_pages.close()
                # plot_actDist(torch.stack(act_dists, dim=0), nums=ind+i)

                def plot_attn():
                    print('index: {} write_attention_pdf ...'.format(i + ind))
                    # write the MDP state attention weight matrix
                    file_name = os.path.join(file_path, '{}.attention.pdf'.format(i+ind))
                    pdf_pages = PdfPages(file_name)
                    plot_attention(attn.cpu().numpy(), name='{}-th article'.format(i+ind),
                        X_label=extracted_batch[ind], Y_label=list(range(len(raw_art_sents))),
                        dirpath=save_path, pdf_page=pdf_pages)
                    pdf_pages.close()
                # plot_attn()

                ext = ext[:-1]
                act = act[:-1]

                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                    act = list([1]*5)[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                    act = [i.item() for i in act]

                ext_nums.append(ext)

                ext_inds += [(len(ext_arts), len(ext))] # [(0,5),(5,7),(7,3),...]
                ext_arts += [raw_art_sents[k] for k in ext]
                ext_acts += [k for k in act]

                # accumulate the sentences extracted so far
                ext_collections += [sum(ext_arts[ext_inds[-1][0]:ext_inds[-1][0]+k+1],[]) for k in range(ext_inds[-1][1])]

                abs_collections += [sum(abs_sents[:k+1],[]) if k<len(abs_sents) 
                                        else sum(abs_sents[0:len(abs_sents)],[]) 
                                        for k in range(ext_inds[-1][1])]

            if beam_size > 1:  # rewrite with beam search
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)

                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]

                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each sent in extracted digest
                    # All abstract mapping
                    rouge_before_rewriten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewriten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewriten - rouge_after_rewriten
                    rewrite_less_rouge.append(diff_ins)
            
            else:  # rewrite greedily
                dec_outs = abstractor(ext_arts)
                dec_collections = [[sum(dec_outs[pos[0]:pos[0]+k+1],[]) for k in range(pos[1])] for pos in ext_inds]
                dec_collections = [x for sublist in dec_collections for x in sublist]
                for index, chooser in enumerate(ext_acts):
                    if chooser == 0:
                        dec_outs_act += [dec_outs[index]]
                    else:
                        dec_outs_act += [ext_arts[index]]
                # dec_outs_act = dec_outs
                # dec_outs_act = ext_arts
                assert len(ext_collections)==len(dec_collections)==len(abs_collections)
                for ext, dec, abss, act in zip(ext_collections, dec_collections, abs_collections, ext_acts):
                    # for each sent in extracted digest
                    # All abstract mapping
                    rouge_before_rewriten = compute_rouge_n(ext, abss, n=1)
                    rouge_after_rewriten = compute_rouge_n(dec, abss, n=1)
                    diff_ins = rouge_before_rewriten - rouge_after_rewriten
                    rewrite_less_rouge.append(diff_ins)

            assert i == batch_size*i_debug

            for iters, (j, n) in enumerate(ext_inds):        
                
                do_right_rewrite = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge<0 and action==0])
                do_right_preserve = sum([1 for rouge, action in zip(rewrite_less_rouge[j:j+n], ext_acts[j:j+n]) if rouge>=0 and action==1])
                
                decoded_words_nums = [len(dec) for dec in dec_outs_act[j:j+n]]
                ext_words_nums = [token_nums_batch[iters][x] for x in range(len(token_nums_batch[iters])) if x in ext_nums[iters]]

                # take all the extracted labels
                # decoded_sents = [raw_article_batch[iters][x] for x in extracted_batch[iters]]
                # statistics [START]
                decoded_sents = [' '.join(dec) for dec in dec_outs_act[j:j+n]]
                rouge_score1 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=1)
                rouge_score2 = compute_rouge_n(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]),n=2)
                rouge_scorel = compute_rouge_l(' '.join(decoded_sents),' '.join(raw_abstract_batch[iters]))
                
                dec_sent_nums = len(decoded_sents)
                abs_sent_nums = len(raw_abstract_batch[iters])
                doc_sent_nums = len(raw_article_batch[iters])
                
                doc_words_nums = sum(token_nums_batch[iters])
                ext_words_nums = sum(ext_words_nums)
                abs_words_nums = sum(decoded_words_nums)

                label_recall = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(extracted_batch[iters])
                label_precision = len(set(ext_nums[iters]) & set(extracted_batch[iters])) / len(ext_nums[iters])
                less_rewrite = rewrite_less_rouge[j+n-1]
                dec_one_action_num = sum(ext_acts[j:j+n])
                dec_zero_action_num = n - dec_one_action_num

                ext_indices = '_'.join([str(i) for i in ext_nums[iters]])
                
                top3 = set([0,1,2]) <= set(ext_nums[iters])
                top3_gold = set([0,1,2]) <= set(extracted_batch[iters])
                
                # Any Top 2 
                top2 = set([0,1]) <= set(ext_nums[iters]) or set([1,2]) <= set(ext_nums[iters]) or set([0,2]) <= set(ext_nums[iters])
                top2_gold = set([0,1]) <= set(extracted_batch[iters]) or set([1,2]) <= set(extracted_batch[iters]) or set([0,2]) <= set(extracted_batch[iters])
                
                with open(join(save_path,'_statisticsDecode.log.csv'),'a') as w:
                    w.write('{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(i,rouge_score1,
                     rouge_score2, rouge_scorel, dec_sent_nums,
                      abs_sent_nums, doc_sent_nums, doc_words_nums,
                      ext_words_nums,abs_words_nums,(ext_words_nums - abs_words_nums),
                      label_recall, label_precision,
                      less_rewrite, dec_one_action_num, dec_zero_action_num, 
                      ext_indices, top3, top3_gold, top2, top2_gold,do_right_rewrite,do_right_preserve))
                # statistics [END]

                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    decoded_sents = [s for s in decoded_sents if s != '']
                    if len(decoded_sents) > 0:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    else:
                        f.write('')

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
            
    print()
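`compute_rouge_n` and `compute_rouge_l` come from the repository's metric module. For orientation, ROUGE-N can be sketched as an F1 over clipped n-gram counts; this is a hedged simplification, since the repo's version may use a different mode (e.g. recall only):

from collections import Counter

def rouge_n_f1(output, reference, n=1):
    # clipped n-gram overlap -> F1 (a sketch, not the repo's exact metric)
    def grams(words):
        return Counter(zip(*(words[i:] for i in range(n))))
    out, ref = grams(output), grams(reference)
    match = sum((out & ref).values())
    if match == 0:
        return 0.0
    prec, rec = match / sum(out.values()), match / sum(ref.values())
    return 2 * prec * rec / (prec + rec)

print(rouge_n_f1('the cat sat'.split(), 'the cat ran'.split()))  # ~0.667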
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
    print()
Example #13
def decode(args, predict=False):
    # save_path = args.path
    batch_size = args.batch
    beam_size = args.beam
    diverse = args.div
    start = time()
    extractor = args.extractor
    abstractor = args.abstractor
    # setup model
    text = ''

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    if not predict:
        dataset = DecodeDataset(args)

        n_data = len(dataset)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=4,
                            collate_fn=coll)
    else:
        n_data = 1
        loader = clean_and_split(args.text)
        loader = [[[' '.join(mecab_tokenizer(line)) for line in loader]]]
        text = '\n'.join(loader[0][0])

    i = 0
    #print(text)
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                #print(ext_arts)
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            source_text = [''.join(sent) for sent in ext_arts]
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                decoded_sents = decoded_sents[:20]
                # with open(join(save_path, 'output/{}.dec'.format(i)),
                #           'w') as f:
                #     f.write(make_html_safe('\n'.join(decoded_sents)))
                result = make_html_safe('\n\n'.join(decoded_sents))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
    return text, result, source_text
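A hedged usage sketch of the `predict=True` path, assuming `args` already carries the loaded `extractor` and `abstractor` plus `batch`, `beam`, `div`, and the raw `text` string (attribute names taken from the function body above):

# hypothetical end-to-end call on one raw document
text, result, source_text = decode(args, predict=True)
print(result)  # HTML-safe summary, sentences separated by blank lines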
def decode(save_path, abs_dir, split, batch_size, max_len, cuda, min_len):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        raise ValueError('abs_dir must not be None!')
    else:
        #abstractor = Abstractor(abs_dir, max_len, cuda)
        abstractor = BeamAbstractor(abs_dir,
                                    max_len,
                                    cuda,
                                    min_len,
                                    reverse=args.reverse)

    bert = abstractor._bert
    if bert:
        tokenizer = abstractor._tokenizer
        import logging
        logging.basicConfig(level=logging.ERROR)

    # if args.docgraph or args.paragraph:
    #     abstractor = BeamAbstractorGAT(abs_dir, max_len, cuda, min_len, reverse=args.reverse)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = AbsDecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)
    save_path = os.path.join(save_path, split)
    os.makedirs(save_path)
    # prepare save paths and logs
    dec_log = {}
    dec_log['abstractor'] = (None if abs_dir is None else json.load(
        open(join(abs_dir, 'meta.json'))))
    dec_log['rl'] = False
    dec_log['split'] = split
    dec_log['beam'] = 5  # beam search with size 5
    beam_size = 5
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)
    os.makedirs(join(save_path, 'output'))

    # Decoding
    i = 0
    length = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            if bert:
                tokenized_article_batch = map(
                    tokenize_keepcase(args.max_input), raw_article_batch)
            else:
                tokenized_article_batch = map(tokenize(args.max_input),
                                              raw_article_batch)
            ext_arts = []
            ext_inds = []
            beam_inds = []
            pre_abs = list(tokenized_article_batch)
            pre_abs = [article[0] for article in pre_abs]
            for j in range(len(pre_abs)):
                beam_inds += [(len(beam_inds), 1)]
            all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
            dec_outs = rerank_mp(all_beams, beam_inds)

            for dec_out in dec_outs:
                if bert:
                    text = ''.join(' '.join(dec_out).split(' '))
                    dec_out = bytearray([
                        tokenizer.byte_decoder[c] for c in text
                    ]).decode('utf-8', errors=tokenizer.errors)
                    dec_out = [dec_out]

                dec_out = sent_tokenize(' '.join(dec_out))
                ext = [sent.split(' ') for sent in dec_out]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += ext
            dec_outs = ext_arts

            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
                length += len(decoded_sents)
        print('average summary length:', length / i)
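The `byte_decoder` round-trip above converts byte-level BPE pieces back into UTF-8 text. A standalone sketch of the same round-trip, assuming a RoBERTa-style slow tokenizer from the `transformers` library:

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained('roberta-base')
pieces = tok.tokenize('Hello world')                    # ['Hello', 'Ġworld']
raw = ''.join(pieces)                                   # spaces encoded as 'Ġ'
text = bytearray([tok.byte_decoder[c] for c in raw])
print(text.decode('utf-8', errors=tok.errors))          # 'Hello world'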
def decode_entity(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
    #if not meta['net_args'].__contains__('abstractor'):
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda, min_len=min_len)

    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda, entity=True)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        batch = list(filter(bool, batch))
        return batch

    if args.key == 1:
        key = 'filtered_rule1_input_mention_cluster'
    elif args.key == 2:
        key = 'filtered_rule23_6_input_mention_cluster'
    else:
        raise ValueError('args.key must be 1 or 2')
    dataset = DecodeDatasetEntity(split, key)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'))
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        sent_selected = 0
        with torch.no_grad():
            for i_debug, raw_input_batch in enumerate(loader):
                raw_article_batch, clusters = zip(*raw_input_batch)
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                #processed_clusters = map(preproc(list(tokenized_article_batch), clusters))
                #processed_clusters = list(zip(*processed_clusters))
                ext_arts = []
                ext_inds = []
                pre_abs = []
                beam_inds = []
                for raw_art_sents, raw_cls in zip(tokenized_article_batch, clusters):
                    processed_clusters = preproc(raw_art_sents, raw_cls)
                    ext = extractor((raw_art_sents, processed_clusters))[:]  # exclude EOE
                    sent_selected += len(ext)
                    if not ext:
                        # use top-3 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(3))[:len(raw_art_sents)]
                    else:
                        ext = [i for i in ext]
                    ext_art = list(map(lambda i: raw_art_sents[i], ext))
                    pre_abs.append([word for sent in ext_art for word in sent])
                    beam_inds += [(len(beam_inds), 1)]

                if beam_size > 1:
                    # all_beams = abstractor(ext_arts, beam_size, diverse)
                    # dec_outs = rerank_mp(all_beams, ext_inds)
                    all_beams = abstractor(pre_abs, beam_size, diverse=1.0)
                    dec_outs = rerank_mp(all_beams, beam_inds)
                else:
                    dec_outs = abstractor(pre_abs)
                for dec_out in dec_outs:
                    dec_out = sent_tokenize(' '.join(dec_out))
                    ext = [sent.split(' ') for sent in dec_out]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += ext

                dec_outs = ext_arts
                assert i == batch_size*i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i/n_data*100,
                        timedelta(seconds=int(time()-start))
                    ), end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None), raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size*i_debug
                for j, n in ext_inds:
                    decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i/n_data*100,
                        timedelta(seconds=int(time()-start))
                    ), end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
    if sc:
        print('average sentences selected:', sent_selected)
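Every variant in this collection writes its decoded sentences through make_html_safe before saving them to the numbered .dec files. That helper is not part of these listings; a minimal sketch of such a function, assuming its only job is to escape angle brackets so the text survives HTML-based ROUGE evaluation scripts:

def make_html_safe(s):
    # minimal sketch: escape angle brackets for HTML-based ROUGE tooling;
    # the real helper in the source repository may differ
    return s.replace("<", "&lt;").replace(">", "&gt;")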
def decode(save_path, model_dir, split, batch_size,
           beam_size, diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = lambda x, y: x  # identity: return extracted sentences unchanged
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'),
                                    max_len, cuda)
        else:
            print('decoding with beam search (beam size {})'.format(beam_size))
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'),
                                        max_len, cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles
    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=coll
    )

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    total_leng = 0
    total_num = 0
    with torch.no_grad():
        for i_debug, data_batch in enumerate(loader):
            raw_article_batch, sent_label_batch = tuple(map(list, unzip(data_batch)))
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_inds = []
            dirty = []
            ext_sents = []
            masks = []
            for raw_art_sents, sent_labels in zip(tokenized_article_batch, sent_label_batch):
                # the last element of ext appears to be the end-of-extraction
                # marker, hence the len(ext) - 1 below
                ext = extractor(raw_art_sents, sent_labels)

                # max_dec_edu is assumed to be a module-level cap on the
                # number of decoded EDUs per article
                tmp_size = min(max_dec_edu, len(ext) - 1)
                ext_inds += [(len(ext_sents), tmp_size)]
                tmp_stop = ext[-1][-1].item()
                tmp_truncate = tmp_stop - 1
                str_arts = list(map(lambda x: ' '.join(x), raw_art_sents))
                for idx in ext[:tmp_size]:
                    t, m = rl_edu_to_sentence(str_arts, idx)
                    total_leng += len(t)
                    total_num += 1
                    assert len(t) == len(m)
                    if t == []:
                        assert len(idx) == 1
                        sent_id = idx[0].item()
                        if sent_id == tmp_truncate:
                            # label / label_mask are assumed to be module-level
                            # placeholder values; flag this slot as dirty so
                            # its output can be blanked after decoding
                            dirty.append(len(ext_sents))
                            ext_sents.append(label)
                            masks.append(label_mask)
                    else:
                        if idx[-1].item() != tmp_stop:
                            ext_sents.append(t)
                            masks.append(m)

            if beam_size > 1:
                all_beams = abstractor(ext_sents, masks, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
                # blank out the slots flagged as dirty above
                for d in dirty:
                    dec_outs[d] = []
            else:
                dec_outs = abstractor(ext_sents, masks)
                for d in dirty:
                    dec_outs[d] = []
            assert i == batch_size*i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j+n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                if i % 100 == 0:
                    print('running average extracted-unit length:',
                          total_leng / total_num)
                i += 1

                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i/n_data*100,
                    timedelta(seconds=int(time()-start))
                ), end='')
    print()
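These decoders call tokenize(None) and hand the result to map, which only works if tokenize is curried: the first argument caps token count, the second is the list of raw sentence strings. A minimal sketch of such a helper, assuming toolz-style currying and lower-cased whitespace tokenization (the actual implementation may differ):

from toolz import curry

@curry
def tokenize(max_len, texts):
    # lower-case and whitespace-split each sentence;
    # slicing with max_len=None keeps every token
    return [t.lower().split()[:max_len] for t in texts]

With this definition, tokenize(None) is a one-argument function over a list of sentences, and a direct two-argument call like tokenize(None, raw_art) also works, matching both usage patterns in these listings.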
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    """
    print('Docoding extraction result....')
    abstractor = identity
    """
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            print('Decoding full model result with beam size 1')
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            print('Decoding full model result with beam size {}'.format(
                beam_size))
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)

    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, batch_data in enumerate(loader):
            batch_data = map(tokenize_decode(None), batch_data)
            ext_arts = []
            ext_inds = []
            for data in batch_data:
                raw_art_sents, topic = data
                ext = extractor(raw_art_sents, topic)[:-1]  # exclude EOE
                if not ext:
                    # fall back to the top-1 sentence if nothing is extracted
                    # (changed from the original top-5); in some rare cases
                    # rnn-ext does not extract at all
                    ext = list(range(1))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]

                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]

            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug

            for j, n in ext_inds:

                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')

    print()
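All of these decode variants read their input through DecodeDataset(split), whose definition never appears in the listings. A minimal sketch, under the assumption that it mirrors the numbered {i}.json layout used elsewhere in this document, with a hypothetical 'article' key holding the source sentences:

import json
import os
from os.path import join

from torch.utils.data import Dataset


class DecodeDataset(Dataset):
    # minimal sketch; DATA_DIR is assumed to be a module-level path, and
    # the 'article' key is a guess mirroring the 'abstract' key read by
    # the dump helpers
    def __init__(self, split):
        self._data_dir = join(DATA_DIR, split)
        self._n_data = len(os.listdir(self._data_dir))

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        with open(join(self._data_dir, '{}.json'.format(i))) as f:
            return json.load(f)['article']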
def decode(save_path, save_file, model_dir, split, batch_size, beam_size,
           diverse, max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()

    # concatenate the per-article outputs into a single file, one summary
    # per line (probably not needed for the CNN/DailyMail dataset)
    summaries_files = os.listdir(join(save_path, 'output'))
    n = len(summaries_files)
    summaries_list = [None] * n

    for fname in summaries_files:
        num = int(fname.replace(".dec", ""))
        with open(join(save_path, "output", fname)) as f_local:
            summaries_list[num] = f_local.read().replace("\n", " ")

    # every numbered slot should have been filled exactly once
    assert all(s is not None for s in summaries_list)

    with open(save_file, "w") as f:
        f.write("\n".join(summaries_list))
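os.listdir returns names in arbitrary order, which is why the loop above pre-allocates the list and indexes by the number parsed from each filename. An equivalent sketch that sorts the filenames numerically instead:

import os
from os.path import join


def read_summaries_in_order(output_dir):
    # sort '12.dec'-style names by their numeric prefix, then read in order
    fnames = sorted(os.listdir(output_dir),
                    key=lambda fname: int(fname.replace('.dec', '')))
    summaries = []
    for fname in fnames:
        with open(join(output_dir, fname)) as f:
            summaries.append(f.read().replace('\n', ' '))
    return summaries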
Example #19
0
def decode(save_path,
           model_dir,
           split,
           batch_size,
           beam_size,
           diverse,
           max_len,
           cuda,
           bart=False,
           clip=-1,
           tri_block=False):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None),
                                          [r[0] for r in raw_article_batch])
            tokenized_abs_batch = map(tokenize(None),
                                      [r[1] for r in raw_article_batch])
            ext_arts = []
            ext_inds = []
            for raw_art_sents, raw_abs_sents in zip(tokenized_article_batch,
                                                    tokenized_abs_batch):
                ext, raw_art_sents = extractor(raw_art_sents,
                                               raw_abs_sents=raw_abs_sents)
                ext = ext[:-1]  # exclude EOE
                if tri_block:
                    # trigram blocking: keep at most 3 sentences, skipping any
                    # candidate that shares a trigram with already-kept ones
                    _pred = []
                    _ids = []
                    for j in ext:
                        if j >= len(raw_art_sents):
                            continue
                        candidate = " ".join(raw_art_sents[j]).strip()
                        if not _block_tri(candidate, _pred):
                            _pred.append(candidate)
                            _ids.append(j)
                            if len(_pred) == 3:
                                break
                    ext = _ids
                if clip > 0 and len(ext) > clip:
                    # optionally clip the number of extracted sentences
                    ext = ext[:clip]
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if bart:
                # tokenizer and bart_model are assumed to be module-level
                # objects loaded elsewhere in the script
                dec_outs = get_bart_summaries(ext_arts,
                                              tokenizer,
                                              bart_model,
                                              beam_size=beam_size)
            else:
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
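The tri_block branch depends on a _block_tri helper that never appears in this listing. A minimal sketch, assuming the usual trigram-blocking rule (reject a candidate that shares any trigram with an already-selected sentence):

def _get_ngrams(n, words):
    # set of n-grams over a token list
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}


def _block_tri(candidate, selected):
    # True if the candidate shares at least one trigram with any
    # already-selected sentence
    tri_candidate = _get_ngrams(3, candidate.split())
    return any(tri_candidate & _get_ngrams(3, s.split()) for s in selected)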
Example #20
0
def decode(args):
    save_path = args.path
    model_dir = args.model_dir
    batch_size = args.batch
    beam_size = args.beam
    diverse = args.div
    max_len = args.max_dec_word
    cuda = args.cuda
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(args)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = args.mode
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                # cap each summary at 20 sentences
                decoded_sents = decoded_sents[:20]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
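This variant reads everything from an argparse namespace. A plausible wiring, under the assumption that the flag names match the attributes accessed above (path, model_dir, mode, batch, beam, div, max_dec_word, cuda); the defaults are illustrative:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='decode summaries')
    parser.add_argument('--path', required=True, help='output directory')
    parser.add_argument('--model_dir', required=True,
                        help='trained model directory')
    parser.add_argument('--mode', default='test', help='data split to decode')
    parser.add_argument('--batch', type=int, default=32, help='batch size')
    parser.add_argument('--beam', type=int, default=1, help='beam size')
    parser.add_argument('--div', type=float, default=1.0,
                        help='diverse beam penalty')
    parser.add_argument('--max_dec_word', type=int, default=30,
                        help='max words per decoded sentence')
    parser.add_argument('--cuda', action='store_true', help='use GPU')
    decode(parser.parse_args())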
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda, sc, min_len):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided then
        #       the whole model would be extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda, min_len)

    if sc:
        extractor = SCExtractor(model_dir, cuda=cuda)
    else:
        extractor = RLExtractor(model_dir, cuda=cuda)

    # check whether the extractor network was trained with BERT features
    try:
        _bert = extractor._net._bert
    except AttributeError:
        _bert = False
        print('extractor has no _bert attribute; assuming no BERT')

    if _bert:
        tokenizer = BertTokenizer.from_pretrained(
            'bert-large-uncased-whole-word-masking')
        print('bert tokenizer loaded')

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {}
    dec_log['abstractor'] = meta['net_args']['abstractor']
    dec_log['extractor'] = meta['net_args']['extractor']
    dec_log['rl'] = True
    dec_log['split'] = split
    dec_log['beam'] = beam_size
    dec_log['diverse'] = diverse
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    if sc:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
                ext_arts = []
                ext_inds = []
                if _bert:
                    for raw_art_sents, raw_art in zip(tokenized_article_batch,
                                                      raw_article_batch):
                        tokenized_sents = [
                            tokenizer.tokenize(source_sent.lower())
                            for source_sent in raw_art
                        ]
                        tokenized_sents = [
                            tokenized_sent + ['[SEP]']
                            for tokenized_sent in tokenized_sents
                        ]
                        tokenized_sents[0] = ['[CLS]'] + tokenized_sents[0]
                        word_num = [
                            len(tokenized_sent)
                            for tokenized_sent in tokenized_sents
                        ]
                        truncated_word_num = []
                        total_count = 0
                        for num in word_num:
                            if total_count + num < MAX_LEN_BERT:
                                truncated_word_num.append(num)
                            else:
                                truncated_word_num.append(MAX_LEN_BERT -
                                                          total_count)
                                break
                            total_count += num
                        tokenized_sents = list(
                            concat(tokenized_sents))[:MAX_LEN_BERT]
                        tokenized_sents = tokenizer.convert_tokens_to_ids(
                            tokenized_sents)
                        art_sents = tokenize(None, raw_art)
                        _input = (art_sents, tokenized_sents,
                                  truncated_word_num)

                        # NOTE: [:] keeps every index; no EOE token is
                        # dropped here, despite the pattern used elsewhere
                        ext = extractor(_input)[:]
                        if not ext:
                            # use top-3 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(3))[:len(raw_art_sents)]
                        else:
                            ext = list(ext)
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                else:
                    for raw_art_sents in tokenized_article_batch:
                        # NOTE: [:] keeps every index; no EOE token is
                        # dropped here, despite the pattern used elsewhere
                        ext = extractor(raw_art_sents)[:]
                        if not ext:
                            # use top-5 if nothing is extracted
                            # in some rare cases rnn-ext does not extract at all
                            ext = list(range(5))[:len(raw_art_sents)]
                        else:
                            ext = list(ext)
                        ext_inds += [(len(ext_arts), len(ext))]
                        ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    else:
        i = 0
        length = 0
        with torch.no_grad():
            for i_debug, raw_article_batch in enumerate(loader):
                tokenized_article_batch = map(tokenize(None),
                                              raw_article_batch)
                ext_arts = []
                ext_inds = []
                for raw_art_sents in tokenized_article_batch:
                    ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                    if not ext:
                        # use top-5 if nothing is extracted
                        # in some rare cases rnn-ext does not extract at all
                        ext = list(range(5))[:len(raw_art_sents)]
                    else:
                        ext = [i.item() for i in ext]
                    ext_inds += [(len(ext_arts), len(ext))]
                    ext_arts += [raw_art_sents[i] for i in ext]
                if beam_size > 1:
                    all_beams = abstractor(ext_arts, beam_size, diverse)
                    dec_outs = rerank_mp(all_beams, ext_inds)
                else:
                    dec_outs = abstractor(ext_arts)
                assert i == batch_size * i_debug
                for j, n in ext_inds:
                    decoded_sents = [
                        ' '.join(dec) for dec in dec_outs[j:j + n]
                    ]
                    with open(join(save_path, 'output/{}.dec'.format(i)),
                              'w') as f:
                        f.write(make_html_safe('\n'.join(decoded_sents)))
                    i += 1
                    print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                        i, n_data, i / n_data * 100,
                        timedelta(seconds=int(time() - start))),
                          end='')
                    length += len(decoded_sents)
    print('average summary length:', length / i)
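The _bert branch above flattens each article into one [CLS] ... [SEP] word-piece stream while recording how many pieces belong to each sentence, truncating at MAX_LEN_BERT. A standalone sketch of just that packing step (the MAX_LEN_BERT value is illustrative; the real constant lives at module level):

MAX_LEN_BERT = 512  # illustrative; assumed module-level in the original


def pack_for_bert(tokenized_sents):
    # append [SEP] to every sentence and prepend [CLS] to the first
    sents = [sent + ['[SEP]'] for sent in tokenized_sents]
    sents[0] = ['[CLS]'] + sents[0]
    word_num = [len(sent) for sent in sents]
    # keep per-sentence counts only up to the truncation point;
    # the last surviving count may be partial
    truncated_word_num = []
    total_count = 0
    for num in word_num:
        if total_count + num < MAX_LEN_BERT:
            truncated_word_num.append(num)
        else:
            truncated_word_num.append(MAX_LEN_BERT - total_count)
            break
        total_count += num
    flat = [tok for sent in sents for tok in sent][:MAX_LEN_BERT]
    return flat, truncated_word_num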