def batchify_fn_graph_rl(pad,
                         start,
                         end,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         docgraph=True,
                         reward_data_dir=None):
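    # Collates one RL batch: unpacks the per-example fields, builds a graph
    # adjacency per example according to adj_type (document graph vs. paragraph
    # subgraphs), pads node/source ids into tensors, and returns the forward
    # arguments together with the arguments needed to compute the RL reward.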
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples = tuple(
            map(list, unzip(batch)))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples, node_lists = tuple(
            map(list, unzip(batch)))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    #print('ext_size:', ext_vsize, extend_vsize)
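    # START, END, UNK are assumed to be module-level special-token ids; note
    # that the start/end parameters of this function are not used when
    # building fw_args.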
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, START, END, UNK, 100)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, START, END, UNK,
                   100)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions)

    return fw_args, loss_args
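
# A minimal, self-contained sketch of the kind of adjacency construction the
# helpers used above (make_adj / make_adj_triple / make_adj_edge_in) are
# assumed to perform; the real implementations live elsewhere in the
# repository, so the exact conventions (undirected edges, self loops) are
# assumptions for illustration only.
import torch

def sketch_make_adj(triples, num_nodes, cuda=False):
    # triples: iterable of (head_node, relation, tail_node) index triples.
    adj = torch.zeros(num_nodes, num_nodes)
    for head, _rel, tail in triples:
        adj[head, tail] = 1.
        adj[tail, head] = 1.          # assume an undirected graph
    adj = adj + torch.eye(num_nodes)  # self loops, a common GAT convention
    return adj.cuda() if cuda else adj

# Example: two triples over four nodes yield a 4x4 symmetric matrix.
# print(sketch_make_adj([(0, 0, 1), (2, 1, 3)], 4))
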
def batchify_fn_gat(pad,
                    start,
                    end,
                    data,
                    cuda=True,
                    adj_type='concat_triple',
                    mask_type='none',
                    decoder_supervision=False,
                    docgraph=True):
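    # Collates one supervised (teacher-forcing) batch for the graph-attention
    # abstractor: builds per-example adjacencies, prepends start / appends end
    # ids to the decoder inputs and targets, pads everything into tensors, and
    # returns (fw_args, loss_args).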
    sources, ext_srcs, tar_ins, targets, \
    nodes, nodelengths, sum_worthy, word_freq_feat, nodefreq, relations, rlengths, triples = tuple(map(list, unzip(data)))
    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy,
                    feature_dict), (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists), (_relations, rmask, triples, adjs))
    if 'soft' in mask_type and decoder_supervision:
        raise NotImplementedError('not implemented yet')
        #loss_args = (target, sum_worthy_label, extracted_labels)
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise NotImplementedError('not implemented yet')
        #loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
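
# pad_batch_tensorize is imported from elsewhere in the repository; a rough,
# assumed sketch of its behavior (padding variable-length id lists into one
# LongTensor) is shown below for readability only.
import torch

def sketch_pad_batch_tensorize(inputs, pad, cuda=False):
    # inputs: list of lists of token/node ids.
    batch_size = len(inputs)
    max_len = max(len(ids) for ids in inputs)
    tensor = torch.full((batch_size, max_len), pad, dtype=torch.long)
    for i, ids in enumerate(inputs):
        tensor[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return tensor.cuda() if cuda else tensor

# Example: pads [[1, 2, 3], [4, 5]] with pad id 0 into a (2, 3) tensor.
# print(sketch_pad_batch_tensorize([[1, 2, 3], [4, 5]], pad=0))
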
Example #3
def batchify_fn_gat_bert(tokenizer,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         mask_type='none',
                         docgraph=True):
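    # BPE/BERT-tokenizer variant of batchify_fn_gat: the special-token ids are
    # read from the tokenizer, src_lens arrive precomputed in the batch, and
    # the word in-paragraph frequency feature is not used.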
    sources, ext_srcs, tar_ins, targets, \
    nodes, nodelengths, sum_worthy, nodefreq, relations, rlengths, triples, src_lens = (data[0], ) + tuple(map(list, unzip(data[1])))
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]

    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    #src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy,
                    feature_dict), (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists), (_relations, rmask, triples, adjs))

    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target, )
    return fw_args, loss_args
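
# pad_batch_tensorize_3d is likewise assumed (sketch only) to pad a batch of
# node lists, where each node is itself a variable-length list of token
# positions, into a (batch, max_node_num, max_node_len) LongTensor; the node
# mask used above is then derived by padding with -1 and calling
# .ne(-1).float().
import torch

def sketch_pad_batch_tensorize_3d(batch_nodes, pad, cuda=False):
    batch_size = len(batch_nodes)
    max_num = max(len(nodes) for nodes in batch_nodes)
    max_len = max((len(node) for nodes in batch_nodes for node in nodes),
                  default=1)
    tensor = torch.full((batch_size, max_num, max_len), pad, dtype=torch.long)
    for i, nodes in enumerate(batch_nodes):
        for j, node in enumerate(nodes):
            tensor[i, j, :len(node)] = torch.tensor(node, dtype=torch.long)
    return tensor.cuda() if cuda else tensor
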
Example #4
    def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
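        # Encodes one article (with BERT or the word-level encoder), builds the
        # graph-node representations, and runs the extractor: returns the
        # greedy extraction plus any sampled extractions and their
        # probabilities for policy-gradient training (sampling is skipped when
        # validate=True).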

        if self._enable_docgraph:
            if self._bert:
                raise NotImplementedError
            else:
                raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, triples, relations, sent_aligns = raw_input
            if self._adj_type == 'concat_triple':
                adjs = [
                    make_adj_triple(triples, len(nodes), len(relations),
                                    self._cuda)
                ]
            elif self._adj_type == 'edge_as_node':
                adjs = [
                    make_adj_edge_in(triples, len(nodes), len(relations),
                                     self._cuda)
                ]
            else:
                adjs = [make_adj(triples, len(nodes), len(nodes), self._cuda)]
        else:
            if self._bert:
                _, raw_article_sents, nodes, nodefreq, triples, relations, node_lists, word_nums = raw_input
            else:
                raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, triples, relations, sent_aligns, node_lists = raw_input
            if self._adj_type == 'edge_as_node':
                adjs = [
                    subgraph_make_adj_edge_in((triples, node_lists),
                                              cuda=self._cuda)
                ]
            else:
                adjs = [
                    subgraph_make_adj((triples, node_lists), cuda=self._cuda)
                ]

        if not self._bert:
            sent_word_freq = pad_batch_tensorize(sent_word_freq,
                                                 pad=0,
                                                 max_num=5,
                                                 cuda=self._cuda)
            word_freq_feat = pad_batch_tensorize([word_freq_feat],
                                                 pad=0,
                                                 cuda=self._cuda)

        nodenum = [len(nodes)]
        sentnum = [len(raw_article_sents)]
        nmask = pad_batch_tensorize(
            nodes, pad=-1, cuda=self._cuda).ne(-1).float().unsqueeze(0)
        nodes = pad_batch_tensorize(nodes, pad=0, cuda=self._cuda).unsqueeze(0)
        nodefreq = pad_batch_tensorize([nodefreq], pad=0, cuda=self._cuda)

        if self._bert:
            articles = self._batcher(raw_article_sents)
            articles, article_sent = self._encode_bert(articles, [word_nums])
            enc_sent = self._sent_enc(article_sent[0], None, None,
                                      None).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)
            sent_aligns = None

        else:
            article_sent, articles = self._batcher(raw_article_sents)
            articles = self._sent_enc._embedding(articles)
            sent_aligns = [sent_aligns]
            if self._pe:
                bs, max_art_len, _ = articles.size()
                src_pos = torch.tensor([[i for i in range(max_art_len)]
                                        for _ in range(bs)
                                        ]).to(articles.device)
                src_pos = self._sent_enc.poisition_enc(src_pos)
                articles = torch.cat([articles, src_pos], dim=-1)
            if 'inpara_freq' in self._feature_banks:
                word_inpara_freq = self._sent_enc._inpara_embedding(
                    word_freq_feat)
                articles = torch.cat([articles, word_inpara_freq], dim=-1)
            enc_sent = self._sent_enc(article_sent, None, sent_word_freq,
                                      None).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)

        # print('enc_Art:', enc_art)
        if self._enable_docgraph:
            nodes = self._encode_docgraph(articles,
                                          nodes,
                                          nmask,
                                          adjs,
                                          nodenum,
                                          enc_out=enc_art,
                                          sent_nums=sentnum,
                                          nodefreq=nodefreq)
        else:
            outputs = self._encode_paragraph(articles,
                                             nodes,
                                             nmask,
                                             adjs, [node_lists],
                                             enc_out=enc_art,
                                             sent_nums=sentnum,
                                             nodefreq=nodefreq)
            if self._hierarchical_attn:
                (topics, topic_length), masks, (nodes, node_length,
                                                node_align_paras) = outputs
                node_align_paras = pad_batch_tensorize(node_align_paras,
                                                       pad=0,
                                                       cuda=False).to(
                                                           nodes.device)
            elif 'soft' in self._mask_type:
                (nodes, topic_length), masks = outputs
                topics = None
                node_align_paras = None
            else:
                nodes, topic_length = outputs
                topics = None
                node_align_paras = None

        nodes = nodes.squeeze(0)

        # print('entity out:', entity_out)

        if not validate:
            greedy = self._net(enc_art,
                               nodes,
                               aligns=sent_aligns,
                               paras=(topics, node_align_paras, topic_length))
            samples = []
            probs = []
            sample, prob = self._net.sample(enc_art,
                                            nodes,
                                            aligns=sent_aligns,
                                            paras=(topics, node_align_paras,
                                                   topic_length))
            samples.append(sample)
            probs.append(prob)
        else:
            greedy = self._net(enc_art,
                               nodes,
                               aligns=sent_aligns,
                               paras=(topics, node_align_paras, topic_length))

            samples = []
            probs = []
            # for i in range(sample_time):
            #     sample, prob = self._net.sample(enc_art, nodes, aligns=sent_aligns, paras=(topics, node_align_paras, topic_length))
            #     samples.append(sample)
            #     probs.append(prob)
        return greedy, samples, probs
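
# The (greedy, samples, probs) triple returned by forward() has the usual
# shape for self-critical policy-gradient training: sampled rollouts are
# rewarded relative to the greedy baseline. A hypothetical sketch (reward_fn,
# the scalar rewards, and the per-step log-prob tensor are assumptions, not
# this repository's API):
def sketch_self_critical_loss(greedy, sample, log_probs, reward_fn):
    # reward_fn scores an extraction/summary against its reference (e.g. ROUGE).
    advantage = reward_fn(sample) - reward_fn(greedy)  # baseline-subtracted reward
    # Maximizing expected reward == minimizing the negative advantage-weighted
    # log-probability of the sampled actions.
    return -advantage * log_probs.sum()
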
Example #5
def batchify_fn_graph_rl_bert(tokenizer,
                              data,
                              cuda=True,
                              adj_type='concat_triple',
                              docgraph=True,
                              reward_data_dir=None):
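    # RL-collate variant for the BPE/BERT pipeline: special-token ids are read
    # from the tokenizer, the batch additionally carries teacher-forcing inputs
    # and targets, and questions are only populated when a reward data
    # directory is supplied.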
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    unk = tokenizer.encoder[tokenizer._unk_token]
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, nodefreq, relations, triples, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, nodefreq, relations, triples, node_lists, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=0, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)

    #src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    #print('ext_size:', ext_vsize, extend_vsize)
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, start, end, unk, 150, tar_in)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, start, end, unk,
                   150, tar_in)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions, target)

    return fw_args, loss_args
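
# The extended-vocabulary bookkeeping above (ext_word2id, ext_id2word,
# extend_vsize, ext_vsize) is the usual copy-mechanism setup: source words
# outside the fixed vocabulary get temporary ids so the decoder can copy them
# and the ids can be mapped back to strings afterwards. A minimal sketch of
# building such a mapping (an assumption for illustration, not the
# repository's exact helper):
def sketch_extend_vocab(article_tokens, word2id, unk_id):
    ext_word2id = dict(word2id)
    ext_ids, base_ids = [], []
    for tok in article_tokens:
        if tok not in ext_word2id:
            ext_word2id[tok] = len(ext_word2id)  # temporary OOV id for copying
        ext_ids.append(ext_word2id[tok])
        base_ids.append(word2id.get(tok, unk_id))
    return base_ids, ext_ids, ext_word2id

# ext_vsize = ext_src.max().item() + 1 bounds the ids actually used in this
# batch, while extend_vsize = len(ext_word2id) is the full extended size.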