Example #1
    def __call__(self, raw_article_sents, raw_clusters):
        self._net.eval()
        n_art = len(raw_article_sents)
        articles = conver2id(UNK, self._word2id, raw_article_sents)

        clusters = (conver2id(UNK, self._word2id, raw_clusters[0]),
                    raw_clusters[1], raw_clusters[2])
        article = pad_batch_tensorize(articles, PAD, cuda=False,
                                      max_num=5).to(self._device)
        clusters = (pad_batch_tensorize(clusters[0],
                                        PAD,
                                        cuda=False,
                                        max_num=4).to(self._device),
                    pad_batch_tensorize(clusters[1],
                                        PAD,
                                        cuda=False,
                                        max_num=4).to(self._device),
                    pad_batch_tensorize(clusters[2],
                                        PAD,
                                        cuda=False,
                                        max_num=4).to(self._device))
        if raw_clusters == []:
            print(clusters)

        indices = self._net.extract([article],
                                    clusters,
                                    k=min(n_art, self._max_ext))
        return indices
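Most of the extractor __call__ wrappers in these examples follow the same preprocessing pattern: map each tokenized sentence to word ids with conver2id, then pad the ragged id lists into one LongTensor batch with pad_batch_tensorize. Below is a minimal, self-contained sketch of that pattern; the toy conver2id and pad_batch_tensorize are stand-ins written for illustration, assuming the signatures seen in these examples (the real helpers, UNK, and PAD come from the project).

import torch

UNK, PAD = 0, 1  # assumed special ids; the real values come from the project's vocabulary

def conver2id(unk, word2id, sents):
    # map every token to its id, falling back to UNK for out-of-vocabulary words
    return [[word2id.get(w, unk) for w in sent] for sent in sents]

def pad_batch_tensorize(inputs, pad, cuda=False, max_num=0):
    # pad a list of id lists into a [batch, max_len] LongTensor filled with `pad`;
    # max_num is assumed to enforce a minimum width (useful for conv sentence encoders)
    max_len = max(max(len(ids) for ids in inputs), max_num)
    tensor = torch.full((len(inputs), max_len), pad, dtype=torch.long)
    for i, ids in enumerate(inputs):
        tensor[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return tensor.cuda() if cuda else tensor

word2id = {'the': 2, 'cat': 3, 'sat': 4}
raw_article_sents = [['the', 'cat', 'sat'], ['the', 'dog']]
articles = conver2id(UNK, word2id, raw_article_sents)      # [[2, 3, 4], [2, 0]]
article = pad_batch_tensorize(articles, PAD, cuda=False)   # tensor of shape [2, 3]
print(article)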
Example #2
 def __call__(self, raw_article_sents, raw_query):
     self._net.eval()
     n_art = len(raw_article_sents)
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     queries = conver2id(UNK, self._word2id, raw_query)
     article = pad_batch_tensorize(articles, PAD, cuda=False
                                  ).to(self._device)
     query = pad_batch_tensorize(queries, PAD, cuda=False
                                  ).to(self._device)
     indices = self._net.extract([article], k=min(n_art, self._max_ext), queries=[query])
     return indices
Example #3
    def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
        raw_article_sents, raw_clusters = raw_input
        clusters = (self._batcher(raw_clusters[0]),
                    pad_batch_tensorize(raw_clusters[1], pad=0, max_num=5),
                    pad_batch_tensorize(raw_clusters[2], pad=0, max_num=5),
                    torch.cuda.FloatTensor(raw_clusters[3]),
                    torch.cuda.LongTensor(raw_clusters[4]))

        article_sent = self._batcher(raw_article_sents)
        enc_sent = self._sent_enc(article_sent).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)
        # print('enc_Art:', enc_art)

        entity_out = self._encode_entity(clusters, cluster_nums=None)
        _, _, (entity_out, entity_mask) = self._graph_enc(
            [clusters[3]], [clusters[4]],
            (entity_out.unsqueeze(0),
             torch.tensor([len(raw_clusters[0])], device=entity_out.device)))
        entity_out = entity_out.squeeze(0)

        # print('entity out:', entity_out)

        if self.time_variant and not validate:
            greedy = self._net(enc_art, entity_out)
            samples = []
            probs = []
            sample, prob, new_greedy = self._net.sample(
                enc_art, entity_out, time_varient=self.time_variant)
            samples.append(sample)
            probs.append(prob)
            greedy = [greedy] + new_greedy
            if len(greedy) != len(prob):
                print(len(enc_art[0]))
                print(greedy)
                print(new_greedy)
                print(sample)
                print(prob)
            assert len(greedy) == len(prob)
        else:
            greedy = self._net(enc_art, entity_out)

            samples = []
            probs = []
            for i in range(sample_time):
                sample, prob = self._net.sample(enc_art,
                                                entity_out,
                                                time_varient=False)
                samples.append(sample)
                probs.append(prob)
        return greedy, samples, probs
Example #4
 def __call__(self, raw_article_sents):
     self._net.eval()
     n_art = len(raw_article_sents)
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     article = pad_batch_tensorize(articles, PAD, cuda=False
                                  ).to(self._device)
     indices = self._net.extract([article], k=min(n_art, self._max_ext))
     return indices
Example #5
 def _prepro(self, raw_article_sents):
     ext_word2id = dict(self._word2id)
     ext_id2word = dict(self._id2word)
     for raw_words in raw_article_sents:
         for w in raw_words:
             if not w in ext_word2id:
                 ext_word2id[w] = len(ext_word2id)
                 ext_id2word[len(ext_id2word)] = w
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     art_lens = [len(art) for art in articles]
     article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
     extend_arts = conver2id(UNK, ext_word2id, raw_article_sents)
     extend_art = pad_batch_tensorize(extend_arts, PAD, cuda=False).to(self._device)
     extend_vsize = len(ext_word2id)
     dec_args = (article, art_lens, extend_art, extend_vsize,
                 START, END, UNK, self._max_len)
     return dec_args, ext_id2word
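Example #5's _prepro builds an extended vocabulary so that article words missing from the base vocabulary still receive temporary ids the copy mechanism can point at during decoding. Below is a toy sketch of that id bookkeeping; the vocabulary contents are made up for illustration and only mirror the loop above.

# toy base vocabulary; the real ids come from self._word2id / self._id2word
word2id = {'<unk>': 0, '<pad>': 1, 'the': 2, 'cat': 3}
id2word = {i: w for w, i in word2id.items()}

ext_word2id = dict(word2id)
ext_id2word = dict(id2word)
raw_article_sents = [['the', 'cat', 'meowed'], ['the', 'zebra', 'ran']]
for raw_words in raw_article_sents:
    for w in raw_words:
        if w not in ext_word2id:
            # unseen words get fresh ids appended to the end of the vocabulary
            ext_word2id[w] = len(ext_word2id)
            ext_id2word[len(ext_id2word)] = w

# 'meowed', 'zebra', 'ran' now have ids 4, 5, 6; the decoder can copy them
# even though the generator's softmax only covers the base vocabulary
print(ext_word2id['meowed'], len(ext_word2id))  # 4 7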
Example #6
 def __call__(self, raw_article_sents, sent_labels):
     self._net.eval()
     n_art = len(raw_article_sents)
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     article = pad_batch_tensorize(articles, PAD,
                                   cuda=False).to(self._device)
     indices = self._net.extract([article], [sent_labels])
     return indices
Example #7
 def __call__(self, raw_article_sents):
     articles = [
         self._tokenizer.encode(' '.join(words), add_special_tokens=True)
         for words in raw_article_sents
     ]
     articles = self.trim_length(articles)
     article = pad_batch_tensorize(articles, self._pad,
                                   cuda=False).to(self._device)
     return article
Example #8
 def __call__(self, raw_article_sents, raw_abs_sents=None):
     if self.net_type == 'ml_rnn_extractor':
         articles = conver2id(UNK, self._word2id, raw_article_sents)
         article = pad_batch_tensorize(articles, PAD, cuda=False
                                     ).to(self._device)
     elif self.net_type == 'ml_trans_rnn_extractor':
         # print([" ".join(r) for r in raw_article_sents])
         # print([" ".join(r) for r in raw_abs_sents])
         article = myextract.get_batch_trans([([" ".join(r) for r in raw_article_sents], [" ".join(r) for r in raw_abs_sents])])
     return article
Example #9
 def _prepro(self, raw_article_sents):
     ext_word2id = dict(self._word2id)
     ext_id2word = dict(self._id2word)
     for raw_words in raw_article_sents:
         for w in raw_words:
             if not w in ext_word2id:
                 ext_word2id[w] = len(ext_word2id)
                 ext_id2word[len(ext_id2word)] = w
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     art_lens = [len(art) for art in articles]
     article = pad_batch_tensorize(articles, PAD, cuda=False
                                  ).to(self._device)
     extend_arts = conver2id(UNK, ext_word2id, raw_article_sents)
     extend_art = pad_batch_tensorize(extend_arts, PAD, cuda=False
                                     ).to(self._device)
     extend_vsize = len(ext_word2id)
     dec_args = (article, art_lens, extend_art, extend_vsize,
                 START, END, UNK, self._max_len)
     return dec_args, ext_id2word
Example #10
 def __call__(self, raw_article_sents):
     self._net.eval()
     n_art = len(raw_article_sents)
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     article = pad_batch_tensorize(articles, PAD, cuda=False, max_num=5
                                  ).to(self._device)
     if not self.force_ext:
         indices = self._net.extract([article], k=min(n_art, self._max_ext), force_ext=self.force_ext)
     else:
         indices = self._net.extract([article], k=min(n_art, self._max_ext))
     return indices
Example #11
    def _prepro(self, raw_article_sents):
        ext_word2id = dict(self._word2id)
        ext_id2word = dict(self._id2word)

        articles = conver2id(UNK, self._word2id, raw_article_sents)
        art_lens = [len(art) for art in articles]
        article = pad_batch_tensorize(articles, PAD,
                                      cuda=False).to(self._device)

        dec_args = (article, art_lens, START, END, UNK, self._max_len)
        return dec_args, ext_id2word
Example #12
    def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
        raw_article_sents, raw_clusters = raw_input
        clusters = (self._batcher(raw_clusters[0]),
                    pad_batch_tensorize(raw_clusters[1], pad=0, max_num=5),
                    pad_batch_tensorize(raw_clusters[2], pad=0, max_num=5))

        article_sent = self._batcher(raw_article_sents)
        enc_sent = self._sent_enc(article_sent).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)
        # print('enc_Art:', enc_art)
        if not self._context:
            entity_out = self._encode_entity(clusters, cluster_nums=None)
        else:
            entity_out = self._encode_entity(clusters, cluster_nums=None, context=enc_art)
        # print('entity out:', entity_out)

        if self.time_variant and not validate:
            greedy = self._net(enc_art, entity_out)
            samples = []
            probs = []
            sample, prob, new_greedy = self._net.sample(enc_art, entity_out, time_varient=self.time_variant)
            samples.append(sample)
            probs.append(prob)
            greedy = [greedy] + new_greedy
            if len(greedy) != len(prob):
                print(len(enc_art[0]))
                print(greedy)
                print(new_greedy)
                print(sample)
                print(prob)
            assert len(greedy) == len(prob)
        else:
            greedy = self._net(enc_art, entity_out)
            samples = []
            probs = []
            for i in range(sample_time):
                sample, prob = self._net.sample(enc_art, entity_out, time_varient=False)
                samples.append(sample)
                probs.append(prob)

        return greedy, samples, probs
Example #13
 def __call__(self, raw_article_sents):
     self._net.eval()
     n_art = len(raw_article_sents)
     if self._emb_type == 'W2V':
         articles = conver2id(UNK, self._word2id, raw_article_sents)
     else:
         articles = [self._tokenizer.convert_tokens_to_ids(sentence) 
                     for sentence in raw_article_sents]
     article = pad_batch_tensorize(articles, PAD, cuda=False
                                  ).to(self._device)
     indices = self._net.extract([article], k=min(n_art, self._max_ext))
     return indices
Example #14
 def __call__(self, raw_article_sents):
     self._net.eval()
     n_art = len(raw_article_sents)
     if self._net_type == 'ml_trans_rnn_extractor':
         n_art = len(raw_article_sents[0])
         batch = myextract.get_batch_trans([raw_article_sents])
         indices = self._net.extract(batch, k=min(n_art, self._max_ext))
         return indices, batch
     else:
         articles = conver2id(UNK, self._word2id, raw_article_sents)
         article = pad_batch_tensorize(articles, PAD, cuda=False
                                     ).to(self._device)
         indices = self._net.extract([article], k=min(n_art, self._max_ext))
     return indices
Example #15
    def __call__(self, raw_article_sents):
        tokenized_sents = raw_article_sents
        stride = 256
        tokenized_sents_lists = [tokenized_sents[:BERT_MAX_LEN]]
        length = len(tokenized_sents) - BERT_MAX_LEN
        i = 1
        while length > 0:
            tokenized_sents_lists.append(
                tokenized_sents[i * BERT_MAX_LEN -
                                stride:(i + 1) * BERT_MAX_LEN - stride])
            i += 1
            length -= (BERT_MAX_LEN - stride)
        id_sents = [
            self._tokenizer.convert_tokens_to_ids(tokenized_sents)
            for tokenized_sents in tokenized_sents_lists
        ]

        pad = self._tokenizer.encoder[self._tokenizer._pad_token]

        sources = pad_batch_tensorize(id_sents, pad=pad,
                                      cuda=False).to(self._device)

        return sources
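Example #15 splits token sequences longer than BERT_MAX_LEN into overlapping windows so that consecutive chunks share `stride` tokens of context. The index arithmetic can be checked in isolation; the sketch below assumes BERT_MAX_LEN = 512 (the usual BERT limit), which may differ from the project's constant.

BERT_MAX_LEN = 512   # assumed window size; the project defines its own constant
stride = 256

def window_bounds(n_tokens):
    # return (start, end) index pairs of the overlapping windows, mirroring Example #15
    bounds = [(0, min(n_tokens, BERT_MAX_LEN))]
    length = n_tokens - BERT_MAX_LEN
    i = 1
    while length > 0:
        start = i * BERT_MAX_LEN - stride
        bounds.append((start, min(n_tokens, start + BERT_MAX_LEN)))
        i += 1
        length -= (BERT_MAX_LEN - stride)
    return bounds

print(window_bounds(300))   # [(0, 300)]
print(window_bounds(1000))  # [(0, 512), (256, 768), (768, 1000)]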
Example #16
def batchify_fn_subgraph_nobert(pad,
                                data,
                                cuda=True,
                                adj_type='edge_as_node',
                                mask_type='none',
                                model_type='gat'):
    assert adj_type in [
        'no_edge', 'edge_up', 'edge_down', 'concat_triple', 'edge_as_node'
    ]
    source_lists, targets, source_articles, nodes, sum_worthy, relations, triples, node_lists, dec_selection_mask, \
    sent_align_paras, segment_feat_sent, segment_feat_para, nodefreq, word_inpara_freq, sent_word_inpara_freq = tuple(map(list, unzip(data)))
    if adj_type == 'edge_as_node':
        batch_adjs = list(
            map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples,
                                                          node_lists)))
    else:
        batch_adjs = list(
            map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    # print('adj:', batch_adjs[0][0])
    # print('node list:', node_lists[0][0])
    # print('triple:', triples[0][0])

    src_nums = [len(source_list) for source_list in source_lists]
    source_articles = pad_batch_tensorize(source_articles, pad=pad, cuda=cuda)
    segment_feat_para = pad_batch_tensorize(segment_feat_para,
                                            pad=pad,
                                            cuda=cuda)
    #sources = list(map(pad_batch_tensorize(pad=pad, cuda=cuda), source_lists))
    sources = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5), source_lists))
    segment_feat_sent = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5),
            segment_feat_sent))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_inpara_freq = pad_batch_tensorize(word_inpara_freq,
                                           pad=pad,
                                           cuda=cuda)
    sent_word_inpara_freq = list(
        map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5),
            sent_word_inpara_freq))
    # source_lists = [source for source_list in source_lists for source in source_list]
    # sources = pad_batch_tensorize(source_lists, pad=pad, cuda=cuda)
    #print('extracted labels:', extracted_labels)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    dec_selection_mask = pad_batch_tensorize(dec_selection_mask,
                                             pad=0,
                                             cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()
    segment_features = pad_batch_tensorize  # note: binds the function object itself; not used in the rest of this function

    # PAD is -1 (dummy extraction index) for using sequence loss
    target = pad_batch_tensorize(targets, pad=-1, cuda=cuda)
    remove_last = lambda tgt: tgt[:-1]
    tar_in = pad_batch_tensorize(
        list(map(remove_last, targets)),
        pad=-0,
        cuda=cuda  # use 0 here for feeding first conv sentence repr.
    )
    feature_dict = {
        'seg_para': segment_feat_para,
        'seg_sent': segment_feat_sent,
        'sent_inpara_freq': sent_word_inpara_freq,
        'word_inpara_freq': word_inpara_freq,
        'node_freq': nodefreq
    }

    fw_args = (src_nums, tar_in, (sources, source_articles, feature_dict),
               (_nodes, nmask, node_num, sum_worthy,
                dec_selection_mask), (_relations, rmask, triples, batch_adjs,
                                      node_lists, sent_align_paras))
    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    # elif decoder_supervision:
    #     loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
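In batchify_fn_subgraph_nobert, pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5) is mapped directly over source_lists, which suggests the project's helper is curried (calling it with keyword arguments alone returns a partially applied function). The sketch below reproduces that per-document padding with functools.partial and a toy stand-in padder; names and shapes are illustrative only.

from functools import partial
import torch

def pad_batch_tensorize(inputs, pad=0, cuda=False, max_num=0):
    # toy stand-in: pad a list of id lists to a [batch, max_len] LongTensor
    max_len = max(max(len(ids) for ids in inputs), max_num)
    out = torch.full((len(inputs), max_len), pad, dtype=torch.long)
    for i, ids in enumerate(inputs):
        out[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return out.cuda() if cuda else out

# one entry per document, each a list of sentences (lists of word ids)
source_lists = [
    [[2, 3, 4], [5, 6]],
    [[7, 8, 9, 10]],
]
pad_fn = partial(pad_batch_tensorize, pad=1, cuda=False, max_num=5)
sources = list(map(pad_fn, source_lists))
print([s.shape for s in sources])  # [torch.Size([2, 5]), torch.Size([1, 5])]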
Example #17
def batchify_fn_graph_rl(pad,
                         start,
                         end,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         docgraph=True,
                         reward_data_dir=None):
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples = tuple(
            map(list, unzip(batch)))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples, node_lists = tuple(
            map(list, unzip(batch)))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    #print('ext_size:', ext_vsize, extend_vsize)
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, START, END, UNK, 100)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, START, END, UNK,
                   100)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions)

    return fw_args, loss_args
Example #18
def batchify_fn_gat_copy_from_graph(pad,
                                    start,
                                    end,
                                    data,
                                    cuda=True,
                                    adj_type='concat_triple',
                                    mask_type='none',
                                    decoder_supervision=False):
    sources, ext_srcs, tar_ins, targets, \
    nodes, nodelengths, sum_worthy, relations, rlengths, triples, \
    all_node_words, ext_node_aligns, gold_copy_mask = tuple(map(list, unzip(data)))
    if adj_type == 'concat_triple':
        adjs = [
            make_adj_triple(triple, len(node), len(relation), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]
    elif adj_type == 'edge_as_node':
        adjs = [
            make_adj_edge_in(triple, len(node), len(relation), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]
    else:
        adjs = [
            make_adj(triple, len(node), len(node), cuda)
            for triple, node, relation in zip(triples, nodes, relations)
        ]

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)
    all_node_word = pad_batch_tensorize(all_node_words, pad, cuda)
    all_node_mask = pad_batch_tensorize(all_node_words, pad=-1,
                                        cuda=cuda).ne(-1).float()
    ext_node_aligns = pad_batch_tensorize(ext_node_aligns, pad=0, cuda=cuda)
    gold_copy_mask = pad_batch_tensorize(gold_copy_mask, pad=0,
                                         cuda=cuda).float()

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
               (_nodes, nmask, node_num, sum_worthy), (_relations, rmask,
                                                       triples, adjs),
               (all_node_word, all_node_mask, ext_node_aligns, gold_copy_mask))
    if 'soft' in mask_type and decoder_supervision:
        raise Exception('not implemented yet')
        #loss_args = (target, sum_worthy_label, extracted_labels)
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise Exception('not implemented yet')
        #loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
Example #19
def batchify_fn_gat(pad,
                    start,
                    end,
                    data,
                    cuda=True,
                    adj_type='concat_triple',
                    mask_type='none',
                    decoder_supervision=False,
                    docgraph=True):
    sources, ext_srcs, tar_ins, targets, \
    nodes, nodelengths, sum_worthy, word_freq_feat, nodefreq, relations, rlengths, triples = tuple(map(list, unzip(data)))
    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy,
                    feature_dict), (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists), (_relations, rmask, triples, adjs))
    if 'soft' in mask_type and decoder_supervision:
        raise Exception('not implemented yet')
        #loss_args = (target, sum_worthy_label, extracted_labels)
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise Exception('not implemented yet')
        #loss_args = (target, extracted_labels)
    else:
        loss_args = (target, )
    return fw_args, loss_args
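Several of the batchify functions build masks by padding the same nested node/relation lists twice: once with 0 to get the ids fed to the encoder, and once with -1 so that .ne(-1) recovers a 0/1 mask of real positions (0 cannot be used for the mask because it is a legitimate id). Below is a small sketch of that trick with a toy 3-D padder standing in for pad_batch_tensorize_3d, whose exact implementation is project-specific.

import torch

def pad_batch_tensorize_3d(inputs, pad=0, cuda=False):
    # toy stand-in: pad a batch of [node][word-id] lists to [batch, max_nodes, max_words]
    max_nodes = max(len(nodes) for nodes in inputs)
    max_words = max((len(n) for nodes in inputs for n in nodes), default=1)
    out = torch.full((len(inputs), max_nodes, max_words), pad, dtype=torch.long)
    for b, nodes in enumerate(inputs):
        for i, ids in enumerate(nodes):
            out[b, i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return out.cuda() if cuda else out

nodes = [[[3, 4], [5]], [[6]]]                                # 2 documents, ragged node lists
_nodes = pad_batch_tensorize_3d(nodes, pad=0)                 # ids fed to the node encoder
nmask = pad_batch_tensorize_3d(nodes, pad=-1).ne(-1).float()  # 1.0 where a real word id sits
print(nmask[0])  # [[1., 1.], [1., 0.]]
print(nmask[1])  # [[1., 0.], [0., 0.]]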
Example #20
    def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):

        if self._enable_docgraph:
            if self._bert:
                raise NotImplementedError
            else:
                raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, triples, relations, sent_aligns = raw_input
            if self._adj_type == 'concat_triple':
                adjs = [
                    make_adj_triple(triples, len(nodes), len(relations),
                                    self._cuda)
                ]
            elif self._adj_type == 'edge_as_node':
                adjs = [
                    make_adj_edge_in(triples, len(nodes), len(relations),
                                     self._cuda)
                ]
            else:
                adjs = [make_adj(triples, len(nodes), len(nodes), self._cuda)]
        else:
            if self._bert:
                _, raw_article_sents, nodes, nodefreq, triples, relations, node_lists, word_nums = raw_input
            else:
                raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, triples, relations, sent_aligns, node_lists = raw_input
            if self._adj_type == 'edge_as_node':
                adjs = [
                    subgraph_make_adj_edge_in((triples, node_lists),
                                              cuda=self._cuda)
                ]
            else:
                adjs = [
                    subgraph_make_adj((triples, node_lists), cuda=self._cuda)
                ]

        if not self._bert:
            sent_word_freq = pad_batch_tensorize(sent_word_freq,
                                                 pad=0,
                                                 max_num=5,
                                                 cuda=self._cuda)
            word_freq_feat = pad_batch_tensorize([word_freq_feat],
                                                 pad=0,
                                                 cuda=self._cuda)

        nodenum = [len(nodes)]
        sentnum = [len(raw_article_sents)]
        nmask = pad_batch_tensorize(
            nodes, pad=-1, cuda=self._cuda).ne(-1).float().unsqueeze(0)
        nodes = pad_batch_tensorize(nodes, pad=0, cuda=self._cuda).unsqueeze(0)
        nodefreq = pad_batch_tensorize([nodefreq], pad=0, cuda=self._cuda)

        if self._bert:
            articles = self._batcher(raw_article_sents)
            articles, article_sent = self._encode_bert(articles, [word_nums])
            enc_sent = self._sent_enc(article_sent[0], None, None,
                                      None).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)
            sent_aligns = None

        else:
            article_sent, articles = self._batcher(raw_article_sents)
            articles = self._sent_enc._embedding(articles)
            sent_aligns = [sent_aligns]
            if self._pe:
                bs, max_art_len, _ = articles.size()
                src_pos = torch.tensor([[i for i in range(max_art_len)]
                                        for _ in range(bs)
                                        ]).to(articles.device)
                src_pos = self._sent_enc.poisition_enc(src_pos)
                articles = torch.cat([articles, src_pos], dim=-1)
            if 'inpara_freq' in self._feature_banks:
                word_inpara_freq = self._sent_enc._inpara_embedding(
                    word_freq_feat)
                articles = torch.cat([articles, word_inpara_freq], dim=-1)
            enc_sent = self._sent_enc(article_sent, None, sent_word_freq,
                                      None).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)

        # print('enc_Art:', enc_art)
        if self._enable_docgraph:
            nodes = self._encode_docgraph(articles,
                                          nodes,
                                          nmask,
                                          adjs,
                                          nodenum,
                                          enc_out=enc_art,
                                          sent_nums=sentnum,
                                          nodefreq=nodefreq)
        else:
            outputs = self._encode_paragraph(articles,
                                             nodes,
                                             nmask,
                                             adjs, [node_lists],
                                             enc_out=enc_art,
                                             sent_nums=sentnum,
                                             nodefreq=nodefreq)
            if self._hierarchical_attn:
                (topics, topic_length), masks, (nodes, node_length,
                                                node_align_paras) = outputs
                node_align_paras = pad_batch_tensorize(node_align_paras,
                                                       pad=0,
                                                       cuda=False).to(
                                                           nodes.device)
            elif 'soft' in self._mask_type:
                (nodes, topic_length), masks = outputs
                topics = None
                node_align_paras = None
            else:
                nodes, topic_length = outputs
                topics = None
                node_align_paras = None

        nodes = nodes.squeeze(0)

        # print('entity out:', entity_out)

        if not validate:
            greedy = self._net(enc_art,
                               nodes,
                               aligns=sent_aligns,
                               paras=(topics, node_align_paras, topic_length))
            samples = []
            probs = []
            sample, prob = self._net.sample(enc_art,
                                            nodes,
                                            aligns=sent_aligns,
                                            paras=(topics, node_align_paras,
                                                   topic_length))
            samples.append(sample)
            probs.append(prob)
        else:
            greedy = self._net(enc_art,
                               nodes,
                               aligns=sent_aligns,
                               paras=(topics, node_align_paras, topic_length))

            samples = []
            probs = []
            # for i in range(sample_time):
            #     sample, prob = self._net.sample(enc_art, nodes, aligns=sent_aligns, paras=(topics, node_align_paras, topic_length))
            #     samples.append(sample)
            #     probs.append(prob)
        return greedy, samples, probs
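As an aside, the positional-index tensor src_pos built in Example #20 with a nested list comprehension is just the row 0..max_art_len-1 repeated for every batch element; torch.arange gives the same result, as the small check below shows.

import torch

bs, max_art_len = 2, 4
# nested-comprehension version used in the example
src_pos_loop = torch.tensor([[i for i in range(max_art_len)] for _ in range(bs)])
# equivalent vectorised construction
src_pos_vec = torch.arange(max_art_len).unsqueeze(0).expand(bs, -1)
assert torch.equal(src_pos_loop, src_pos_vec)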
Example #21
def batchify_fn_gat_bert(tokenizer,
                         data,
                         cuda=True,
                         adj_type='concat_triple',
                         mask_type='none',
                         docgraph=True):
    sources, ext_srcs, tar_ins, targets, \
    nodes, nodelengths, sum_worthy, nodefreq, relations, rlengths, triples, src_lens = (data[0], ) + tuple(map(list, unzip(data[1])))
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]

    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]

    #src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy,
                    feature_dict), (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict,
                    node_lists), (_relations, rmask, triples, adjs))

    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target, )
    return fw_args, loss_args
Example #22
 def __call__(self, raw_article_sents):
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     article = pad_batch_tensorize(articles, PAD, cuda=False
                                  ).to(self._device)
     return article
Example #23
    def forward(self,
                raw_article_sents,
                n_abs=None,
                sample_time=1,
                validate=False):
        if self._bert:
            if self._bert_sent:
                _, article, word_num = raw_article_sents
                article = pad_batch_tensorize(article, pad=0, cuda=True)
                mask = (article != 0).detach().float()
                with torch.no_grad():
                    bert_out = self._bert_model(article)
                bert_hidden = torch.cat(
                    [bert_out[-1][_] for _ in [-4, -3, -2, -1]], dim=-1)
                bert_hidden = self._bert_relu(self._bert_linear(bert_hidden))
                bert_hidden = bert_hidden * mask.unsqueeze(2)
                article_sent = bert_hidden
                enc_sent = self._sent_enc(article_sent).unsqueeze(0)
                # print('enc_sent:', enc_sent)
                enc_art = self._art_enc(enc_sent)
            else:
                _, articles, word_num = raw_article_sents
                if self._bert_stride != 0:
                    source_num = sum(word_num)
                    articles = pad_batch_tensorize(articles, pad=0, cuda=True)
                else:
                    articles = torch.tensor(articles, device='cuda')
                with torch.no_grad():
                    bert_out = self._bert_model(articles)
                bert_hidden = torch.cat(
                    [bert_out[-1][_] for _ in [-4, -3, -2, -1]], dim=-1)
                bert_hidden = self._bert_relu(self._bert_linear(bert_hidden))
                hsz = bert_hidden.size(2)
                if self._bert_stride != 0:
                    batch_id = 0
                    source = torch.zeros(source_num,
                                         hsz).to(bert_hidden.device)
                    if source_num < BERT_MAX_LEN:
                        source[:source_num, :] += bert_hidden[
                            batch_id, :source_num, :]
                        batch_id += 1
                    else:
                        source[:BERT_MAX_LEN, :] += bert_hidden[
                            batch_id, :BERT_MAX_LEN, :]
                        batch_id += 1
                        start = BERT_MAX_LEN
                        while start < source_num:
                            #print(start, source_num, max_source)
                            if start - self._bert_stride + BERT_MAX_LEN < source_num:
                                end = start - self._bert_stride + BERT_MAX_LEN
                                batch_end = BERT_MAX_LEN
                            else:
                                end = source_num
                                batch_end = source_num - start + self._bert_stride
                            source[start:end, :] += bert_hidden[
                                batch_id, self._bert_stride:batch_end, :]
                            batch_id += 1
                            start += (BERT_MAX_LEN - self._bert_stride)
                    bert_hidden = source.unsqueeze(0)
                    del source
                max_word_num = max(word_num)
                # if max_word_num < 5:
                #     max_word_num = 5
                new_word_num = []
                start_num = 0
                for num in word_num:
                    new_word_num.append((start_num, start_num + num))
                    start_num += num
                article_sent = torch.stack([
                    torch.cat([
                        bert_hidden[0, num[0]:num[1], :],
                        torch.zeros(max_word_num - num[1] +
                                    num[0], hsz).to(bert_hidden.device)
                    ],
                              dim=0) if (num[1] - num[0]) != max_word_num else
                    bert_hidden[0, num[0]:num[1], :] for num in new_word_num
                ])
                # print('article_sent:', article_sent)
                enc_sent = self._sent_enc(article_sent).unsqueeze(0)
                # print('enc_sent:', enc_sent)
                enc_art = self._art_enc(enc_sent)
                # print('enc_art:', enc_art)

        else:
            article_sent = self._batcher(raw_article_sents)
            enc_sent = self._sent_enc(article_sent).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)

        if self.time_variant and not validate:
            greedy = self._net(enc_art)
            samples = []
            probs = []
            sample, prob, new_greedy = self._net.sample(
                enc_art, time_varient=self.time_variant)
            samples.append(sample)
            probs.append(prob)
            greedy = [greedy] + new_greedy
            if len(greedy) != len(prob):
                print(len(enc_art[0]))
                print(greedy)
                print(new_greedy)
                print(sample)
                print(prob)
            assert len(greedy) == len(prob)
        else:
            greedy = self._net(enc_art)
            samples = []
            probs = []
            for i in range(sample_time):
                sample, prob = self._net.sample(enc_art, time_varient=False)
                samples.append(sample)
                probs.append(prob)
            # print('samle:', samples)
            # print('greedy:', greedy)

        return greedy, samples, probs
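Example #23 re-splits one long run of BERT hidden states into per-sentence tensors by accumulating word counts into (start, end) offsets and zero-padding each sentence to the longest one. Below is a compact sketch of that offset bookkeeping on toy shapes; the hidden size and word counts are made up.

import torch

hsz = 8
word_num = [3, 5, 2]                               # tokens per sentence
bert_hidden = torch.randn(1, sum(word_num), hsz)   # [1, total_tokens, hsz]

max_word_num = max(word_num)
offsets, start = [], 0
for num in word_num:
    offsets.append((start, start + num))           # (start, end) of each sentence
    start += num

article_sent = torch.stack([
    torch.cat([bert_hidden[0, s:e, :],
               torch.zeros(max_word_num - (e - s), hsz)], dim=0)
    if (e - s) != max_word_num else bert_hidden[0, s:e, :]
    for s, e in offsets
])
print(article_sent.shape)  # torch.Size([3, 5, 8])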
Example #24
 def __call__(self, raw_article_sents):
     articles = conver2id(UNK, self._word2id, raw_article_sents)
     article = pad_batch_tensorize(articles, PAD,
                                   cuda=False).to(self._device)
     return article
Example #25
def batchify_fn_graph_rl_bert(tokenizer,
                              data,
                              cuda=True,
                              adj_type='concat_triple',
                              docgraph=True,
                              reward_data_dir=None):
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    unk = tokenizer.encoder[tokenizer._unk_token]
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data

        questions = []
    if docgraph:
        sources, ext_srcs, nodes, nodefreq, relations, triples, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'concat_triple':
            adjs = [
                make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        elif adj_type == 'edge_as_node':
            adjs = [
                make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
        else:
            adjs = [
                make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)
            ]
    else:
        sources, ext_srcs, nodes, nodefreq, relations, triples, node_lists, src_lens, tar_ins, targets = (
            batch[0], ) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'edge_as_node':
            adjs = list(
                map(subgraph_make_adj_edge_in(cuda=cuda),
                    zip(triples, node_lists)))
        else:
            adjs = list(
                map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=0, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)

    #src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    #print('ext_size:', ext_vsize, extend_vsize)
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, adjs, start, end, unk, 150, tar_in)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize, _nodes, nmask,
                   node_num, feature_dict, node_lists, adjs, start, end, unk,
                   150, tar_in)

    loss_args = (raw_articles, ext_id2word, raw_targets, questions, target)

    return fw_args, loss_args
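batchify_fn_graph_rl_bert sizes the copy vocabulary from the extended sources themselves (ext_src.max().item() + 1) and inverts ext_word2id so predicted ids can be mapped back to surface words. Below is a toy illustration of that bookkeeping; the ids and words are made up.

import torch

ext_word2id = {'<unk>': 0, '<pad>': 1, 'the': 2, 'cat': 3, 'meowed': 4}
ext_srcs = [[2, 3, 4], [2, 3]]

ext_src = torch.tensor([ids + [1] * (3 - len(ids)) for ids in ext_srcs])  # pad to equal length
ext_vsize = ext_src.max().item() + 1           # 5: large enough to index every id in the batch
extend_vsize = len(ext_word2id)                # 5: full extended vocabulary size
ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
print(ext_vsize, extend_vsize, ext_id2word[4])  # 5 5 meowed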