Example #1
def build_batchers(word2id, cuda, debug):
    prepro = prepro_fn(args.max_art, args.max_abs)

    def sort_key(sample):
        src, target = sample
        return (len(target), len(src))

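    # Assumed cytoolz-style compose (right-to-left): convert_batch_copy runs
    # first, mapping tokens to ids and handling OOVs via UNK for the copy
    # mechanism; batchify_fn_copy then pads with PAD and adds START/END.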
    batchify = compose(batchify_fn_copy(PAD, START, END, cuda=cuda),
                       convert_batch_copy(UNK, word2id))

    train_loader = DataLoader(MatchDataset('train'),
                              batch_size=BUCKET_SIZE,
                              shuffle=not debug,
                              num_workers=4 if cuda and not debug else 0,
                              collate_fn=coll_fn)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)

    val_loader = DataLoader(MatchDataset('val'),
                            batch_size=BUCKET_SIZE,
                            shuffle=False,
                            num_workers=4 if cuda and not debug else 0,
                            collate_fn=coll_fn)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)
    return train_batcher, val_batcher
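
Note: batchify is built with compose, presumably cytoolz.compose (plain toolz behaves the same), which chains functions right-to-left: the convert step runs first, then the batchify step. A minimal sketch of that semantics, with stand-in functions:

from cytoolz import compose

to_ids = lambda tokens: [len(t) for t in tokens]     # stand-in for the convert step
pad_to_8 = lambda ids: ids + [0] * (8 - len(ids))    # stand-in for the batchify step

batchify = compose(pad_to_8, to_ids)        # batchify(x) == pad_to_8(to_ids(x))
print(batchify(['the', 'quick', 'fox']))    # [3, 5, 3, 0, 0, 0, 0, 0]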
Example #2
def build_batchers(net_type, word2id, cuda, debug, use_bert, bert_tokenizer):
    assert net_type in ['ff', 'rnn']
    def sort_key(sample):
        src_sents, _ = sample
        return len(src_sents)

    if not use_bert:
        prepro = prepro_fn_extract(args.max_word, args.max_sent)
        batchify_fn = (batchify_fn_extract_ff if net_type == 'ff'
                       else batchify_fn_extract_ptr)
        convert_batch = (convert_batch_extract_ff if net_type == 'ff'
                         else convert_batch_extract_ptr)
        batchify = compose(batchify_fn(PAD, cuda=cuda),
                           convert_batch(UNK, word2id))
    else:
        prepro = prepro_fn_identity
        batchify_fn = batchify_fn_bert_extract_ptr2
        convert_batch = convert_batch_bert_extract_ptr3
        batchify = compose(
            batchify_fn(bert_tokenizer.pad_token_id, cuda=cuda),
            convert_batch(bert_tokenizer, max_len=args.max_word,
                          max_sent=args.max_sent))

    train_loader = DataLoader(
        ExtractDataset('train'), batch_size=BUCKET_SIZE,
        shuffle=not debug,
        num_workers=4 if cuda and not debug else 0,
        collate_fn=coll_fn_extract
    )
    train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify,
                                      single_run=False, fork=not debug)

    val_loader = DataLoader(
        ExtractDataset('val'), batch_size=BUCKET_SIZE,
        shuffle=False, num_workers=4 if cuda and not debug else 0,
        collate_fn=coll_fn_extract
    )
    val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify,
                                    single_run=True, fork=not debug)
    return train_batcher, val_batcher
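
Note: sort_key drives length-based bucketing. The DataLoader emits oversized buckets (BUCKET_SIZE samples), which BucketedGenerater presumably sorts by length and slices into mini-batches so each batch pads to a similar length. A self-contained sketch of that idea (the helper and batch size are illustrative, not the repo's API):

def bucket_batches(bucket, sort_key, batch_size):
    """Sort an oversized bucket by length, then slice into mini-batches."""
    ordered = sorted(bucket, key=sort_key)
    for i in range(0, len(ordered), batch_size):
        yield ordered[i:i + batch_size]

bucket = [(['w'] * n, [0]) for n in (5, 40, 7, 38, 6, 41)]
for batch in bucket_batches(bucket, lambda s: len(s[0]), batch_size=3):
    print([len(src) for src, _ in batch])   # [5, 6, 7] then [38, 40, 41]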
Example #3
def build_batchers(net_type, word2id, cuda, debug):
    assert net_type in ['ff', 'rnn', 'trans_rnn']

    def sort_key(sample):
        src_sents, _ = sample
        return len(src_sents)

    if net_type == 'trans_rnn':
        prepro = prepro_fn_extract_trans(args.max_word, args.max_sent)
        batchify = compose(batchify_fn_extract_trans(cuda=cuda),
                           convert_batch_extract_trans)
    else:
        prepro = prepro_fn_extract(args.max_word, args.max_sent)
        batchify_fn = (batchify_fn_extract_ff
                       if net_type == 'ff' else batchify_fn_extract_ptr)
        convert_batch = (convert_batch_extract_ff
                         if net_type == 'ff' else convert_batch_extract_ptr)
        batchify = compose(batchify_fn(PAD, cuda=cuda),
                           convert_batch(UNK, word2id))

    train_loader = DataLoader(ExtractDataset('train'),
                              batch_size=BUCKET_SIZE,
                              shuffle=not debug,
                              num_workers=4 if cuda and not debug else 0,
                              collate_fn=coll_fn_extract)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)

    val_loader = DataLoader(ExtractDataset('val'),
                            batch_size=BUCKET_SIZE,
                            shuffle=False,
                            num_workers=4 if cuda and not debug else 0,
                            collate_fn=coll_fn_extract)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)
    return train_batcher, val_batcher
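
Note: prepro_fn_extract(args.max_word, args.max_sent) is evidently a closure that truncates each article before batching. A hypothetical reimplementation (the sample layout and names are assumptions, not the repo's code):

def prepro_fn_extract(max_word, max_sent):
    """Hypothetical sketch: cap sentences per article and words per sentence."""
    def prepro(batch):
        return [([sent[:max_word] for sent in src[:max_sent]], labels)
                for src, labels in batch]
    return prepro

prepro = prepro_fn_extract(max_word=5, max_sent=2)
sample = ([['a'] * 9, ['b'] * 3, ['c'] * 4], [0, 1])
print(prepro([sample]))   # keeps two sentences, clips the first to 5 tokens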
Example #4
def build_batchers_entity(net_type, word2id, cuda, debug):
    assert net_type in ['entity']

    prepro = prepro_fn_extract_entity(args.max_word, args.max_sent)

    def sort_key(sample):
        # entity samples carry extra fields beyond (src, labels), so index
        # positionally instead of unpacking
        src_sents = sample[0]
        return len(src_sents)

    key = 'filtered_rule23_6_input_mention_cluster'

    batchify_fn = batchify_fn_extract_ptr_entity
    convert_batch = convert_batch_extract_ptr_entity

    batchify = compose(batchify_fn(PAD, cuda=cuda),
                       convert_batch(UNK, word2id))

    train_loader = DataLoader(
        EntityExtractDataset_combine('train', key), batch_size=BUCKET_SIZE,
        shuffle=not debug,
        num_workers=4 if cuda and not debug else 0,
        collate_fn=coll_fn_extract_entity
    )
    train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify,
                                      single_run=False, fork=not debug)

    val_loader = DataLoader(
        EntityExtractDataset_combine('val', key), batch_size=BUCKET_SIZE,
        shuffle=False, num_workers=4 if cuda and not debug else 0,
        collate_fn=coll_fn_extract_entity
    )
    val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify,
                                    single_run=True, fork=not debug)

    return train_batcher, val_batcher
Example #5
def build_batchers_bert(cuda, debug, bert_model):
    tokenizer = RobertaTokenizer.from_pretrained(bert_model)
    #tokenizer = BertTokenizer.from_pretrained(bert_model)
    prepro = prepro_fn_copy_bert(tokenizer, args.max_art, args.max_abs)

    def sort_key(sample):
        src, target = sample[0], sample[1]
        return (len(target), len(src))

    batchify = compose(batchify_fn_copy_bert(tokenizer, cuda=cuda),
                       convert_batch_copy_bert(tokenizer, args.max_art))

    train_loader = DataLoader(SumDataset('train'),
                              batch_size=BUCKET_SIZE,
                              shuffle=not debug,
                              num_workers=4 if cuda and not debug else 0,
                              collate_fn=coll_fn)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)
    val_loader = DataLoader(SumDataset('val'),
                            batch_size=BUCKET_SIZE,
                            shuffle=False,
                            num_workers=4 if cuda and not debug else 0,
                            collate_fn=coll_fn)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)

    return train_batcher, val_batcher, tokenizer.encoder
Example #6
def build_batchers(decoder, emb_type, word2id, cuda, debug):
    prepro = prepro_fn_extract(args.max_word, args.max_sent, emb_type)

    def sort_key(sample):
        src_sents, _ = sample
        return len(src_sents)

    batchify_fn = batchify_fn_extract_ptr
    convert_batch = convert_batch_extract_ptr
    batchify = compose(batchify_fn(PAD, cuda=cuda),
                       convert_batch(UNK, word2id, emb_type))

    train_loader = DataLoader(ExtractDataset('train'),
                              batch_size=BUCKET_SIZE,
                              shuffle=not debug,
                              num_workers=4 if cuda and not debug else 0,
                              collate_fn=coll_fn_extract)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)

    val_loader = DataLoader(ExtractDataset('val'),
                            batch_size=BUCKET_SIZE,
                            shuffle=False,
                            num_workers=4 if cuda and not debug else 0,
                            collate_fn=coll_fn_extract)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)
    return train_batcher, val_batcher
Example #7
def build_batchers(net_type, word2id, cuda, debug, if_neusum=False, stop=False, combine=False):
    assert net_type in ['ff', 'rnn', 'nnse']
    assert not (combine and if_neusum)
    prepro = prepro_fn_extract(args.max_word, args.max_sent)
    def sort_key(sample):
        src_sents, _ = sample
        return len(src_sents)
    batchify_fn = (batchify_fn_extract_nnse if net_type == 'nnse'
                   else batchify_fn_extract_ff if net_type == 'ff'
                   else batchify_fn_extract_ptr)
    if stop:
        print('add stop')
        convert_batch = (convert_batch_extract_ff if net_type in ['ff', 'nnse']
                         else convert_batch_extract_ptr_stop)
    else:
        convert_batch = (convert_batch_extract_ff if net_type in ['ff', 'nnse']
                         else convert_batch_extract_ptr)
    batchify = compose(batchify_fn(PAD, cuda=cuda),
                       convert_batch(UNK, word2id))

    if if_neusum:
        print('Use neusum construction')
        train_loader = DataLoader(
            ExtractDataset_neusum('train'), batch_size=BUCKET_SIZE,
            shuffle=not debug,
            num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify,
                                          single_run=False, fork=not debug)

        val_loader = DataLoader(
            ExtractDataset_neusum('val'), batch_size=BUCKET_SIZE,
            shuffle=False, num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify,
                                    single_run=True, fork=not debug)
    elif combine:
        print('Use combine construction')
        train_loader = DataLoader(
            ExtractDataset_combine('train'), batch_size=BUCKET_SIZE,
            shuffle=not debug,
            num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify,
                                          single_run=False, fork=not debug)

        val_loader = DataLoader(
            ExtractDataset_combine('val'), batch_size=BUCKET_SIZE,
            shuffle=False, num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify,
                                        single_run=True, fork=not debug)
    else:
        train_loader = DataLoader(
            ExtractDataset('train'), batch_size=BUCKET_SIZE,
            shuffle=not debug,
            num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify,
                                          single_run=False, fork=not debug)

        val_loader = DataLoader(
            ExtractDataset('val'), batch_size=BUCKET_SIZE,
            shuffle=False, num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract
        )
        val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify,
                                        single_run=True, fork=not debug)
    return train_batcher, val_batcher
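
Note: the three branches above differ only in the dataset class; the loader and batcher construction repeats verbatim, as it does across every example on this page. A hedged refactor sketch (make_batchers is my name, not the repo's):

def make_batchers(dataset_cls, prepro, sort_key, batchify, cuda, debug):
    """Hypothetical helper: build train/val batchers for a dataset class."""
    def make(split, train):
        loader = DataLoader(
            dataset_cls(split), batch_size=BUCKET_SIZE,
            shuffle=train and not debug,
            num_workers=4 if cuda and not debug else 0,
            collate_fn=coll_fn_extract)
        return BucketedGenerater(loader, prepro, sort_key, batchify,
                                 single_run=not train, fork=not debug)
    return make('train', True), make('val', False)

dataset_cls = (ExtractDataset_neusum if if_neusum
               else ExtractDataset_combine if combine
               else ExtractDataset)
train_batcher, val_batcher = make_batchers(dataset_cls, prepro, sort_key,
                                           batchify, cuda, debug)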
Example #8
def build_batchers_gat_bert(cuda,
                            debug,
                            gold_key,
                            adj_type,
                            mask_type,
                            subgraph,
                            num_worker=4,
                            bert_model='roberta-base'):
    print('adj_type:', adj_type)
    print('mask_type:', mask_type)
    docgraph = not subgraph
    tokenizer = RobertaTokenizer.from_pretrained(bert_model)
    #tokenizer = BertTokenizer.from_pretrained(bert_model)

    with open(os.path.join(DATA_DIR, 'roberta-base-align.pkl'), 'rb') as f:
        align = pickle.load(f)

    prepro = prepro_fn_gat_bert(tokenizer,
                                align,
                                args.max_art,
                                args.max_abs,
                                key=gold_key,
                                adj_type=adj_type,
                                docgraph=docgraph)
    if not subgraph:
        key = 'nodes_pruned2'
        _coll_fn = coll_fn_gat(max_node_num=200)
    else:
        key = 'nodes'
        _coll_fn = coll_fn_gat(max_node_num=400)

    def sort_key(sample):
        src, target = sample[0], sample[1]
        return (len(target), len(src))

    batchify = compose(
        batchify_fn_gat_bert(tokenizer,
                             cuda=cuda,
                             adj_type=adj_type,
                             mask_type=mask_type,
                             docgraph=docgraph),
        convert_batch_gat_bert(tokenizer, args.max_art))

    train_loader = DataLoader(
        MatchDataset_graph('train', key=key, subgraph=subgraph),
        batch_size=BUCKET_SIZE,
        shuffle=not debug,
        num_workers=num_worker if cuda and not debug else 0,
        collate_fn=_coll_fn)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)
    val_loader = DataLoader(
        MatchDataset_graph('val', key=key, subgraph=subgraph),
        batch_size=BUCKET_SIZE,
        shuffle=False,
        num_workers=num_worker if cuda and not debug else 0,
        collate_fn=_coll_fn)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)

    return train_batcher, val_batcher, tokenizer.encoder
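
Note: like Example #5, this returns tokenizer.encoder alongside the batchers. For RobertaTokenizer that attribute is the BPE vocabulary dict (token to id), which the caller presumably uses as its word2id table:

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
word2id = tokenizer.encoder                   # BPE token -> id dict
id2word = {i: t for t, i in word2id.items()}
print(len(word2id))                           # 50265 for roberta-base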
Example #9
def build_batchers_gat(word2id,
                       cuda,
                       debug,
                       gold_key,
                       adj_type,
                       mask_type,
                       subgraph,
                       num_worker=4):
    print('adj_type:', adj_type)
    print('mask_type:', mask_type)
    docgraph = not subgraph
    prepro = prepro_fn_gat(args.max_art,
                           args.max_abs,
                           key=gold_key,
                           adj_type=adj_type,
                           docgraph=docgraph)
    if not subgraph:
        key = 'nodes_pruned2'
        _coll_fn = coll_fn_gat(max_node_num=200)
    else:
        key = 'nodes'
        _coll_fn = coll_fn_gat(max_node_num=400)

    def sort_key(sample):
        src, target = sample[0], sample[1]
        return (len(target), len(src))

    batchify = compose(
        batchify_fn_gat(PAD,
                        START,
                        END,
                        cuda=cuda,
                        adj_type=adj_type,
                        mask_type=mask_type,
                        docgraph=docgraph),
        convert_batch_gat(UNK, word2id))

    train_loader = DataLoader(
        MatchDataset_graph('train', key=key, subgraph=subgraph),
        batch_size=BUCKET_SIZE,
        shuffle=not debug,
        num_workers=num_worker if cuda and not debug else 0,
        collate_fn=_coll_fn)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)
    val_loader = DataLoader(
        MatchDataset_graph('val', key=key, subgraph=subgraph),
        batch_size=BUCKET_SIZE,
        shuffle=False,
        num_workers=num_worker if cuda and not debug else 0,
        collate_fn=_coll_fn)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)

    return train_batcher, val_batcher
Example #10
def build_batchers_graph_bert(tokenizer, cuda, debug, key, adj_type, docgraph,
                              reward_data_dir):
    #prepro = prepro_graph(args.max_art, args.max_abs, adj_type, docgraph=docgraph, reward_data_dir=reward_data_dir)
    with open(os.path.join(DATA_DIR, 'roberta-base-align.pkl'), 'rb') as f:
        align = pickle.load(f)

    prepro = prepro_graph_bert(tokenizer,
                               align,
                               args.max_art,
                               args.max_abs,
                               adj_type,
                               docgraph=docgraph,
                               reward_data_dir=reward_data_dir)

    def sort_key(sample):
        src, target, nodes = sample[0], sample[1], sample[3]
        return (len(src), len(target), len(nodes))

    batchify = compose(
        batchify_fn_graph_rl_bert(tokenizer,
                                  cuda=cuda,
                                  adj_type=adj_type,
                                  docgraph=docgraph,
                                  reward_data_dir=reward_data_dir),
        convert_batch_graph_rl_bert(tokenizer,
                                    args.max_art,
                                    docgraph=docgraph,
                                    reward_data_dir=reward_data_dir))
    if reward_data_dir is not None:
        if docgraph:
            _coll_fn = coll_fn_graph_rl(max_node=400)
        else:
            _coll_fn = coll_fn_graph_rl(max_node=800)
    else:
        _coll_fn = coll_fn_graph

    train_loader = DataLoader(Dataset_RLgraph('train',
                                              key,
                                              reward_data_dir=reward_data_dir),
                              batch_size=BUCKET_SIZE,
                              shuffle=not debug,
                              num_workers=8 if cuda and not debug else 0,
                              collate_fn=_coll_fn)
    train_batcher = BucketedGenerater(train_loader,
                                      prepro,
                                      sort_key,
                                      batchify,
                                      single_run=False,
                                      fork=not debug)
    val_loader = DataLoader(Dataset_RLgraph('val',
                                            key,
                                            reward_data_dir=reward_data_dir),
                            batch_size=BUCKET_SIZE,
                            shuffle=False,
                            num_workers=8 if cuda and not debug else 0,
                            collate_fn=_coll_fn)
    val_batcher = BucketedGenerater(val_loader,
                                    prepro,
                                    sort_key,
                                    batchify,
                                    single_run=True,
                                    fork=not debug)

    return train_batcher, val_batcher
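
Note: this sort_key returns a (src, target, nodes) length triple rather than the (target, src) pair used elsewhere. Python compares tuples lexicographically, so buckets sort primarily by article length, with target and node counts as tie-breakers:

samples = [([1] * 30, [1] * 8, None, [0] * 5),
           ([1] * 30, [1] * 6, None, [0] * 9),
           ([1] * 12, [1] * 9, None, [0] * 2)]

def sort_key(sample):
    src, target, nodes = sample[0], sample[1], sample[3]
    return (len(src), len(target), len(nodes))

print([sort_key(s) for s in sorted(samples, key=sort_key)])
# [(12, 9, 2), (30, 6, 9), (30, 8, 5)]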