Example 1
def decode(raw_sentences,
           extractor,
           abstractor,
           beam_size,
           diverse=1,
           token_threshold=5,
           postpro=False,
           debug=False):
    with torch.no_grad():
        start = time()
        # the extractor and abstractor models are passed in already set up

        tokenized_sentences = tokenize(None, raw_sentences)
        ext = extractor(tokenized_sentences)[:-1]  # exclude EOE
        if not ext:
            # use top-5 if nothing is extracted
            # in some rare cases rnn-ext does not extract at all
            ext = list(range(5))[:len(tokenized_sentences)]
        else:
            ext = [i.item() for i in ext]
        ext_sentences = [tokenized_sentences[i] for i in ext]

        if beam_size > 1:
            all_beams = abstractor(ext_sentences, beam_size, diverse)
            dec_outs = rerank_mp(all_beams, [(0, len(ext_sentences))],
                                 debug=debug)
        else:
            dec_outs = abstractor(ext_sentences, debug=debug)

        attns = None
        if debug:
            dec_outs, attns = dec_outs
            attns = [[t[:len(ext_sentences[i])] for t in attn]
                     for i, attn in enumerate(attns)]
            source_col_name = 'source'
            attns = [
                pd.DataFrame({
                    source_col_name: ext_sentences[i],
                    **{dec_outs[i][j]: t
                       for j, t in enumerate(attn)},
                }).set_index(source_col_name) for i, attn in enumerate(attns)
            ]

        if postpro:
            decoded_sentences = postprocess(dec_outs, token_threshold)
        else:
            decoded_sentences = [' '.join(dec) for dec in dec_outs]

        logging.info('decoded {} sentences in {:.3f}s'.format(
            len(raw_sentences),
            time() - start))

        if debug:
            return (ext, decoded_sentences), attns

        return ext, decoded_sentences
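A minimal usage sketch for the decode helper above, assuming RLExtractor and BeamAbstractor are the pre-trained model wrappers seen in the later examples; the model paths are hypothetical:

import torch

cuda = torch.cuda.is_available()
extractor = RLExtractor('path/to/model_dir', cuda=cuda)  # emits an EOE marker, as decode assumes
abstractor = BeamAbstractor('path/to/model_dir/abstractor', max_len=30, cuda=cuda)

raw_sentences = [
    'the quick brown fox jumps over the lazy dog .',
    'it was last seen near the river on tuesday .',
]
ext_ids, summary = decode(raw_sentences, extractor, abstractor,
                          beam_size=5, diverse=1.0)
print(ext_ids)   # indices of the extracted sentences
print(summary)   # abstractive rewrites of those sentences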
Example 2
def decode(args, split):
    decode_path = join(args.path, 'abs_decode')
    os.makedirs(decode_path, exist_ok=True)

    dataset = CnnDmDataset(split, DATA_DIR)
    print('Generating abstracts for {} dataset'.format(split))
    for i in tqdm(range(len(dataset))):
        file_path = join(decode_path, '{}.dec'.format(i))
        if not args.force and os.path.exists(file_path):
            continue

        js_data = dataset[i]
        art_sents, extracts = (js_data['article'], js_data['extracted'])

        if not extracts:
            abs_results = []
        else:
            abs_results = abstract_callback(
                args, tokenize(None, (art_sents[i] for i in extracts)))

        with open(file_path, 'w', encoding='utf8') as f:
            f.write('\n'.join(abs_results))
Example 3
def read_file(file_dir, file_name):
    with open(join(file_dir, file_name)) as f:
        return tokenize(None, f.read().splitlines())
Example 4
def read_data_file(file_dir, file_name):
    with open(join(file_dir, file_name)) as f:
        return tokenize(None, json.load(f)['article'])
Example 5
def coll(batch):
    art_batch, abs_batch = unzip(batch)
    art_sents = list(filter(bool, map(tokenize(None), art_batch)))
    abs_sents = list(filter(bool, map(tokenize(None), abs_batch)))
    return art_sents, abs_sents
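Examples 3 through 5 (and the batch loops further down) call tokenize(None) as a partially applied function, which only works if tokenize is curried. A minimal sketch of such a helper, assuming a toolz-style curry decorator and plain whitespace tokenization; the actual implementation in this codebase may differ:

from toolz import curry

@curry
def tokenize(max_len, texts):
    # lower-case and whitespace-split each text; max_len=None keeps every token
    return [t.lower().split()[:max_len] for t in texts]

# fully applied: tokenize a list of sentences directly
tokenize(None, ['Hello world .', 'Foo bar .'])   # [['hello', 'world', '.'], ['foo', 'bar', '.']]

# partially applied: map the unary tokenize(None) over a batch of articles
list(map(tokenize(None), [['First sentence .'], ['Second sentence .']]))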
Example 6
def decode(save_path, abs_dir, ext_dir, split, batch_size, max_len, cuda):
    start = time()
    # setup model
    if abs_dir is None:
        # NOTE: if no abstractor is provided,
        #       the whole model reduces to extractive summarization
        abstractor = identity
    else:
        abstractor = Abstractor(abs_dir, max_len, cuda)
    if ext_dir is None:
        # NOTE: if no extractor is provided,
        #       a lead-N extractor is used instead
        extractor = lambda art_sents: list(range(len(art_sents)))[:MAX_ABS_NUM]
    else:
        extractor = Extractor(ext_dir, max_ext=MAX_ABS_NUM, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split, DATASET_DIR)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    for i in range(MAX_ABS_NUM):
        os.makedirs(join(save_path, 'output_{}'.format(i)))
    dec_log = {
        'abstractor': (None if abs_dir is None else json.load(
            open(join(abs_dir, 'meta.json')))),
        'extractor': (None if ext_dir is None else json.load(
            open(join(ext_dir, 'meta.json')))),
        'rl': False,
        'split': split,
        'beam': 1
    }
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += list(map(lambda i: raw_art_sents[i], ext))
            dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                for k, dec_str in enumerate(decoded_sents):
                    with open(join(save_path, 'output_{}/{}.dec'.format(k, i)),
                              'w') as f:
                        f.write(make_html_safe(dec_str))

                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()
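Example 6 falls back to an identity abstractor and a lead-N extractor when no model directories are supplied. A minimal sketch of what those fallbacks amount to, with MAX_ABS_NUM set to an illustrative value (the real constant is defined elsewhere in the codebase):

MAX_ABS_NUM = 5  # illustrative; defined elsewhere in the codebase

def identity(x):
    # abstractor fallback: return the extracted sentences unchanged
    return x

def lead_n_extractor(art_sents):
    # extractor fallback: take the first MAX_ABS_NUM sentence indices
    return list(range(len(art_sents)))[:MAX_ABS_NUM]

article = [['first', 'sentence', '.'], ['second', 'sentence', '.']]
print(lead_n_extractor(article))   # [0, 1]
print(identity(article))           # the same token lists, untouched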
Example 7
def decode(save_path, model_dir, split, batch_size, beam_size, diverse,
           max_len, cuda):
    start = time()
    # setup model
    with open(join(model_dir, 'meta.json')) as f:
        meta = json.loads(f.read())
    if meta['net_args']['abstractor'] is None:
        # NOTE: if no abstractor is provided,
        #       the whole model reduces to extractive summarization
        assert beam_size == 1
        abstractor = identity
    else:
        if beam_size == 1:
            abstractor = Abstractor(join(model_dir, 'abstractor'), max_len,
                                    cuda)
        else:
            abstractor = BeamAbstractor(join(model_dir, 'abstractor'), max_len,
                                        cuda)
    extractor = RLExtractor(model_dir, cuda=cuda)

    # setup loader
    def coll(batch):
        articles = list(filter(bool, batch))
        return articles

    dataset = DecodeDataset(split, DATASET_DIR)

    n_data = len(dataset)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=coll)

    # prepare save paths and logs
    os.makedirs(join(save_path, 'output'), exist_ok=True)
    dec_log = {
        'abstractor': meta['net_args']['abstractor'],
        'extractor': meta['net_args']['extractor'],
        'rl': True,
        'split': split,
        'beam': beam_size,
        'diverse': diverse
    }
    with open(join(save_path, 'log.json'), 'w') as f:
        json.dump(dec_log, f, indent=4)

    # Decoding
    i = 0
    with torch.no_grad():
        for i_debug, raw_article_batch in enumerate(loader):
            tokenized_article_batch = map(tokenize(None), raw_article_batch)
            ext_arts = []
            ext_inds = []
            for raw_art_sents in tokenized_article_batch:
                ext = extractor(raw_art_sents)[:-1]  # exclude EOE
                if not ext:
                    # use top-5 if nothing is extracted
                    # in some rare cases rnn-ext does not extract at all
                    ext = list(range(5))[:len(raw_art_sents)]
                else:
                    ext = [i.item() for i in ext]
                ext_inds += [(len(ext_arts), len(ext))]
                ext_arts += [raw_art_sents[i] for i in ext]
            if beam_size > 1:
                all_beams = abstractor(ext_arts, beam_size, diverse)
                dec_outs = rerank_mp(all_beams, ext_inds)
            else:
                dec_outs = abstractor(ext_arts)
            assert i == batch_size * i_debug
            for j, n in ext_inds:
                decoded_sents = [' '.join(dec) for dec in dec_outs[j:j + n]]
                with open(join(save_path, 'output/{}.dec'.format(i)),
                          'w') as f:
                    f.write(make_html_safe('\n'.join(decoded_sents)))
                i += 1
                print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                    i, n_data, i / n_data * 100,
                    timedelta(seconds=int(time() - start))),
                      end='')
    print()