Example #1
def batchify_fn_subgraph_nobert(pad,
                                data,
                                cuda=True,
                                adj_type='edge_as_node',
                                mask_type='none',
                                model_type='gat'):
    assert adj_type in [
        'no_edge', 'edge_up', 'edge_down', 'concat_triple', 'edge_as_node'
    ]
    (source_lists, targets, source_articles, nodes, sum_worthy, relations,
     triples, node_lists, dec_selection_mask, sent_align_paras,
     segment_feat_sent, segment_feat_para, nodefreq, word_inpara_freq,
     sent_word_inpara_freq) = tuple(map(list, unzip(data)))
    if adj_type == 'edge_as_node':
        batch_adjs = list(
            map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples,
                                                          node_lists)))
    else:
        batch_adjs = list(
            map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    src_nums = [len(source_list) for source_list in source_lists]
    source_articles = pad_batch_tensorize(source_articles, pad=pad, cuda=cuda)
    segment_feat_para = pad_batch_tensorize(segment_feat_para,
                                            pad=pad,
                                            cuda=cuda)
    sources = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5), source_lists))
    segment_feat_sent = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5),
            segment_feat_sent))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_inpara_freq = pad_batch_tensorize(word_inpara_freq,
                                           pad=pad,
                                           cuda=cuda)
    sent_word_inpara_freq = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5),
            sent_word_inpara_freq))

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    dec_selection_mask = pad_batch_tensorize(dec_selection_mask,
                                             pad=0,
                                             cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    # PAD is -1 (dummy extraction index) for using sequence loss
    target = pad_batch_tensorize(targets, pad=-1, cuda=cuda)
    remove_last = lambda tgt: tgt[:-1]
    tar_in = pad_batch_tensorize(
        list(map(remove_last, targets)),
        pad=0,  # use 0 here for feeding first conv sentence repr.
        cuda=cuda)
    feature_dict = {
        'seg_para': segment_feat_para,
        'seg_sent': segment_feat_sent,
        'sent_inpara_freq': sent_word_inpara_freq,
        'word_inpara_freq': word_inpara_freq,
        'node_freq': nodefreq
    }

    fw_args = (src_nums, tar_in,
               (sources, source_articles, feature_dict),
               (_nodes, nmask, node_num, sum_worthy, dec_selection_mask),
               (_relations, rmask, triples, batch_adjs, node_lists,
                sent_align_paras))
    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    # elif decoder_supervision:
    #     loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
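
The helper functions used above (pad_batch_tensorize, pad_batch_tensorize_3d, unzip, the *_make_adj* builders) live elsewhere in the repo. Note that pad_batch_tensorize is used in two ways: called directly with data, and called with keyword arguments only to get back a function that map() then applies per example. A minimal sketch consistent with both call styles (the repo's actual helper may differ, e.g. it may use a @curry decorator) could look like this:

import torch

def pad_batch_tensorize(inputs=None, pad=0, cuda=True, max_num=0):
    # Called without data, return a one-argument closure so the helper can
    # be curried and handed to map(...), as in the examples above.
    if inputs is None:
        return lambda xs: pad_batch_tensorize(xs, pad=pad, cuda=cuda,
                                              max_num=max_num)
    batch_size = len(inputs)
    max_len = max(max(len(ids) for ids in inputs), max_num)
    tensor = torch.full((batch_size, max_len), pad, dtype=torch.long)
    for i, ids in enumerate(inputs):
        tensor[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return tensor.cuda() if cuda else tensor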
Example #2
def batchify_fn_graph_rl(pad,
                         start,
                         end,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         docgraph=True,
                         reward_data_dir=None):
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples = tuple(
            map(list, unzip(batch)))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples, node_lists = tuple(
            map(list, unzip(batch)))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    src_lens = [len(src) for src in sources]

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    if docgraph:
        # START/END/UNK are module-level token-id constants; 100 is
        # presumably the maximum number of decoding steps
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, START, END, UNK, 100)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, START, END, UNK,
                   100)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions)

    return fw_args, loss_args
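
The adj_type switch controls how (subject, relation, object) triples become an adjacency matrix; for 'edge_as_node' each relation is promoted to a vertex of its own, wired to its subject and object nodes. The actual builders are defined elsewhere in the repo; a self-contained sketch of the assumed 'edge_as_node' layout:

import torch

def make_adj_edge_in_sketch(triples, node_num, edge_num, cuda=False):
    # Vertices 0..node_num-1 are entity nodes; the next edge_num vertices
    # stand for the relations themselves ("edge as node").
    size = node_num + edge_num
    adj = torch.eye(size)                 # self-loops
    for sub, rel, obj in triples:         # indices into nodes / relations
        e = node_num + rel
        adj[sub, e] = adj[e, sub] = 1.
        adj[obj, e] = adj[e, obj] = 1.
    return adj.cuda() if cuda else adj

# two entity nodes joined through one relation vertex:
print(make_adj_edge_in_sketch([(0, 0, 1)], node_num=2, edge_num=1))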
Example #3
def batchify_fn_gat_copy_from_graph(pad,
                                    start,
                                    end,
                                    data,
                                    cuda=True,
                                    adj_type='concat_triple',
                                    mask_type='none',
                                    decoder_supervision=False):
    (sources, ext_srcs, tar_ins, targets, nodes, nodelengths, sum_worthy,
     relations, rlengths, triples, all_node_words, ext_node_aligns,
     gold_copy_mask) = tuple(map(list, unzip(data)))
    if adj_type == 'concat_triple':
        adjs = [
            make_adj_triple(triple, len(node), len(relation), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]
    elif adj_type == 'edge_as_node':
        adjs = [
            make_adj_edge_in(triple, len(node), len(relation), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]
    else:
        adjs = [
            make_adj(triple, len(node), len(node), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]

    src_lens = [len(src) for src in sources]

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)
    all_node_word = pad_batch_tensorize(all_node_words, pad, cuda)
    all_node_mask = pad_batch_tensorize(all_node_words, pad=-1,
                                        cuda=cuda).ne(-1).float()
    ext_node_aligns = pad_batch_tensorize(ext_node_aligns, pad=0, cuda=cuda)
    gold_copy_mask = pad_batch_tensorize(gold_copy_mask, pad=0,
                                         cuda=cuda).float()

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
               (_nodes, nmask, node_num, sum_worthy),
               (_relations, rmask, triples, adjs),
               (all_node_word, all_node_mask, ext_node_aligns, gold_copy_mask))
    if 'soft' in mask_type and decoder_supervision:
        raise NotImplementedError
        #loss_args = (target, sum_worthy_label, extracted_labels)
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise NotImplementedError
        #loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
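
The tar_ins/targets pair built above is the standard teacher-forcing shift: the decoder input is the gold sequence prefixed with start, the loss target is the same sequence suffixed with end, so step t of the input lines up with step t of the target. In miniature (the token ids are made up):

start, end = 1, 2
tgt = [7, 9, 4]
tar_in = [start] + tgt      # [1, 7, 9, 4]  fed to the decoder
target = tgt + [end]        # [7, 9, 4, 2]  scored by the loss
assert len(tar_in) == len(target)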
Example #4
def batchify_fn_gat(pad,
                    start,
                    end,
                    data,
                    cuda=True,
                    adj_type='concat_triple',
                    mask_type='none',
                    decoder_supervision=False,
                    docgraph=True):
    (sources, ext_srcs, tar_ins, targets, nodes, nodelengths, sum_worthy,
     word_freq_feat, nodefreq, relations, rlengths,
     triples) = tuple(map(list, unzip(data)))
    if not docgraph:
        # for paragraph-level subgraphs the "nodelengths" slot carries the
        # per-paragraph node lists
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    src_lens = [len(src) for src in sources]

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists),
                   (_relations, rmask, triples, adjs))
    if 'soft' in mask_type and decoder_supervision:
        raise NotImplementedError
        #loss_args = (target, sum_worthy_label, extracted_labels)
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise NotImplementedError
        #loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
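
pad_batch_tensorize_3d pads the per-node word-id lists along both the node and the word axis. The examples also rely on a double-padding trick: the same data is padded once with 0 for the embedding lookup and once with -1 purely to derive a mask via .ne(-1), so a genuine id 0 is never mistaken for padding. A minimal sketch of the assumed helper:

import torch

def pad_batch_tensorize_3d(inputs, pad=0, cuda=True):
    # inputs: one list of nodes per example, each node a list of word ids
    batch = len(inputs)
    max_nodes = max(len(nodes) for nodes in inputs)
    max_words = max((len(ws) for nodes in inputs for ws in nodes), default=1)
    t = torch.full((batch, max_nodes, max_words), pad, dtype=torch.long)
    for i, nodes in enumerate(inputs):
        for j, ws in enumerate(nodes):
            t[i, j, :len(ws)] = torch.tensor(ws, dtype=torch.long)
    return t.cuda() if cuda else t

nodes = [[[3, 4], [5]], [[6]]]
nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=False).ne(-1).float()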
Example #5
def batchify_fn_gat_bert(tokenizer,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         mask_type='none',
                         docgraph=True):
    (sources, ext_srcs, tar_ins, targets, nodes, nodelengths, sum_worthy,
     nodefreq, relations, rlengths, triples,
     src_lens) = (data[0], ) + tuple(map(list, unzip(data[1])))
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]

    if not docgraph:
        # for paragraph-level subgraphs the "nodelengths" slot carries the
        # per-paragraph node lists
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    # src_lens arrives precomputed in the batch, so it is not re-derived here

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists),
                   (_relations, rmask, triples, adjs))

    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target, )
    return fw_args, loss_args
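
tokenizer.encoder[tokenizer._bos_token] reaches into private attributes of a GPT-2/RoBERTa-style tokenizer. If you adapt this code, the public transformers API exposes the same ids directly, which is the safer spelling:

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained('roberta-base')
start, end = tok.bos_token_id, tok.eos_token_id
pad, unk = tok.pad_token_id, tok.unk_token_id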
Example #6
def batchify_fn_graph_rl_bert(tokenizer,
                              data,
                              cuda=True,
                              adj_type='concat_triple',
                              docgraph=True,
                              reward_data_dir=None):
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    unk = tokenizer.encoder[tokenizer._unk_token]
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, nodefreq, relations, triples, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, nodefreq, relations, triples, node_lists, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=0, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)

    # src_lens arrives precomputed in the batch, so it is not re-derived here

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    if docgraph:
        # start/end/unk ids come from the tokenizer; 150 is presumably the
        # maximum number of decoding steps
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, start, end, unk, 150, tar_in)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, start, end, unk,
                   150, tar_in)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions, target)

    return fw_args, loss_args
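
extend_vsize and ext_id2word are the usual copy-mechanism bookkeeping: source-side OOV words receive temporary ids appended after the fixed vocabulary, so the decoder can emit them and the reverse map turns them back into strings for reward computation. A toy illustration (the vocabulary contents are made up):

vocab = {'<pad>': 0, '<unk>': 1, 'the': 2}
ext_word2id = dict(vocab)
for word in ['obama', 'nasa']:        # copyable OOVs from the source article
    ext_word2id.setdefault(word, len(ext_word2id))
extend_vsize = len(ext_word2id)       # 5
ext_id2word = {i: w for w, i in ext_word2id.items()}
assert ext_id2word[3] == 'obama' and extend_vsize == 5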