def __call__(self, raw_article_sents, raw_clusters):
    self._net.eval()
    n_art = len(raw_article_sents)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    clusters = (conver2id(UNK, self._word2id, raw_clusters[0]),
                raw_clusters[1], raw_clusters[2])
    article = pad_batch_tensorize(articles, PAD, cuda=False, max_num=5).to(self._device)
    clusters = (pad_batch_tensorize(clusters[0], PAD, cuda=False, max_num=4).to(self._device),
                pad_batch_tensorize(clusters[1], PAD, cuda=False, max_num=4).to(self._device),
                pad_batch_tensorize(clusters[2], PAD, cuda=False, max_num=4).to(self._device))
    if raw_clusters == []:
        print(clusters)
    indices = self._net.extract([article], clusters, k=min(n_art, self._max_ext))
    return indices
def __call__(self, raw_article_sents, raw_query):
    self._net.eval()
    n_art = len(raw_article_sents)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    queries = conver2id(UNK, self._word2id, raw_query)
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    query = pad_batch_tensorize(queries, PAD, cuda=False).to(self._device)
    indices = self._net.extract([article], k=min(n_art, self._max_ext), queries=[query])
    return indices
def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
    """Run greedy and sampled extraction over one article and its entity clusters
    (clusters are additionally passed through the graph encoder)."""
    raw_article_sents, raw_clusters = raw_input
    clusters = (self._batcher(raw_clusters[0]),
                pad_batch_tensorize(raw_clusters[1], pad=0, max_num=5),
                pad_batch_tensorize(raw_clusters[2], pad=0, max_num=5),
                torch.cuda.FloatTensor(raw_clusters[3]),
                torch.cuda.LongTensor(raw_clusters[4]))
    article_sent = self._batcher(raw_article_sents)
    enc_sent = self._sent_enc(article_sent).unsqueeze(0)
    enc_art = self._art_enc(enc_sent)
    entity_out = self._encode_entity(clusters, cluster_nums=None)
    _, _, (entity_out, entity_mask) = self._graph_enc(
        [clusters[3]], [clusters[4]],
        (entity_out.unsqueeze(0),
         torch.tensor([len(raw_clusters[0])], device=entity_out.device)))
    entity_out = entity_out.squeeze(0)
    if self.time_variant and not validate:
        greedy = self._net(enc_art, entity_out)
        samples = []
        probs = []
        sample, prob, new_greedy = self._net.sample(
            enc_art, entity_out, time_varient=self.time_variant)
        samples.append(sample)
        probs.append(prob)
        greedy = [greedy] + new_greedy
        if len(greedy) != len(prob):
            # debug output before the assertion fires
            print(len(enc_art[0]))
            print(greedy)
            print(new_greedy)
            print(sample)
            print(prob)
        assert len(greedy) == len(prob)
    else:
        greedy = self._net(enc_art, entity_out)
        samples = []
        probs = []
        for i in range(sample_time):
            sample, prob = self._net.sample(enc_art, entity_out, time_varient=False)
            samples.append(sample)
            probs.append(prob)
    return greedy, samples, probs
def __call__(self, raw_article_sents):
    self._net.eval()
    n_art = len(raw_article_sents)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    indices = self._net.extract([article], k=min(n_art, self._max_ext))
    return indices
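# --- Illustrative sketch (not part of the repo) -------------------------------
# Every batcher in this file leans on pad_batch_tensorize. Its real definition is
# not shown here; the stand-in below only mirrors the behaviour implied by the
# call sites: pad variable-length id lists into a LongTensor, optionally pad the
# batch dimension up to max_num, and support a curried keyword-only form that is
# mapped over nested lists. Names and semantics are assumptions, not the repo's API.
import torch

def pad_batch_tensorize_sketch(inputs=None, pad=0, cuda=True, max_num=0):
    if inputs is None:
        # curried form, as used in map(pad_batch_tensorize(pad=pad, ...), lists)
        return lambda batch: pad_batch_tensorize_sketch(batch, pad=pad, cuda=cuda, max_num=max_num)
    tensor_type = torch.cuda.LongTensor if cuda else torch.LongTensor
    batch_size = max(len(inputs), max_num)
    max_len = max(len(ids) for ids in inputs)
    tensor = tensor_type(batch_size, max_len).fill_(pad)
    for i, ids in enumerate(inputs):
        tensor[i, :len(ids)] = tensor_type(ids)
    return tensor

# toy usage: two "sentences" of word ids, zero-padded
# pad_batch_tensorize_sketch([[3, 4, 5], [6, 7]], pad=0, cuda=False)
# -> tensor([[3, 4, 5],
#            [6, 7, 0]])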
def _prepro(self, raw_article_sents):
    ext_word2id = dict(self._word2id)
    ext_id2word = dict(self._id2word)
    for raw_words in raw_article_sents:
        for w in raw_words:
            if w not in ext_word2id:
                ext_word2id[w] = len(ext_word2id)
                ext_id2word[len(ext_id2word)] = w
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    art_lens = [len(art) for art in articles]
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    extend_arts = conver2id(UNK, ext_word2id, raw_article_sents)
    extend_art = pad_batch_tensorize(extend_arts, PAD, cuda=False).to(self._device)
    extend_vsize = len(ext_word2id)
    dec_args = (article, art_lens, extend_art, extend_vsize,
                START, END, UNK, self._max_len)
    return dec_args, ext_id2word
def __call__(self, raw_article_sents, sent_labels):
    self._net.eval()
    n_art = len(raw_article_sents)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    indices = self._net.extract([article], [sent_labels])
    return indices
def __call__(self, raw_article_sents):
    articles = [
        self._tokenizer.encode(' '.join(words), add_special_tokens=True)
        for words in raw_article_sents
    ]
    articles = self.trim_length(articles)
    article = pad_batch_tensorize(articles, self._pad, cuda=False).to(self._device)
    return article
def __call__(self, raw_article_sents, raw_abs_sents=None):
    if self.net_type == 'ml_rnn_extractor':
        articles = conver2id(UNK, self._word2id, raw_article_sents)
        article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    elif self.net_type == 'ml_trans_rnn_extractor':
        article = myextract.get_batch_trans(
            [([" ".join(r) for r in raw_article_sents],
              [" ".join(r) for r in raw_abs_sents])])
    else:
        # guard added: any other net_type would otherwise leave `article` undefined
        raise ValueError('unknown net_type: {}'.format(self.net_type))
    return article
def _prepro(self, raw_article_sents):
    ext_word2id = dict(self._word2id)
    ext_id2word = dict(self._id2word)
    for raw_words in raw_article_sents:
        for w in raw_words:
            if w not in ext_word2id:
                ext_word2id[w] = len(ext_word2id)
                ext_id2word[len(ext_id2word)] = w
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    art_lens = [len(art) for art in articles]
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    extend_arts = conver2id(UNK, ext_word2id, raw_article_sents)
    extend_art = pad_batch_tensorize(extend_arts, PAD, cuda=False).to(self._device)
    extend_vsize = len(ext_word2id)
    dec_args = (article, art_lens, extend_art, extend_vsize,
                START, END, UNK, self._max_len)
    return dec_args, ext_id2word
def __call__(self, raw_article_sents):
    self._net.eval()
    n_art = len(raw_article_sents)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    article = pad_batch_tensorize(articles, PAD, cuda=False, max_num=5).to(self._device)
    if not self.force_ext:
        indices = self._net.extract([article], k=min(n_art, self._max_ext),
                                    force_ext=self.force_ext)
    else:
        indices = self._net.extract([article], k=min(n_art, self._max_ext))
    return indices
def _prepro(self, raw_article_sents):
    ext_word2id = dict(self._word2id)
    ext_id2word = dict(self._id2word)
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    art_lens = [len(art) for art in articles]
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    dec_args = (article, art_lens, START, END, UNK, self._max_len)
    return dec_args, ext_id2word
def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
    """Run greedy and sampled extraction over one article and its entity clusters
    (no graph encoder; the cluster encoding optionally conditions on the article context)."""
    raw_article_sents, raw_clusters = raw_input
    clusters = (self._batcher(raw_clusters[0]),
                pad_batch_tensorize(raw_clusters[1], pad=0, max_num=5),
                pad_batch_tensorize(raw_clusters[2], pad=0, max_num=5))
    article_sent = self._batcher(raw_article_sents)
    enc_sent = self._sent_enc(article_sent).unsqueeze(0)
    enc_art = self._art_enc(enc_sent)
    if not self._context:
        entity_out = self._encode_entity(clusters, cluster_nums=None)
    else:
        entity_out = self._encode_entity(clusters, cluster_nums=None, context=enc_art)
    if self.time_variant and not validate:
        greedy = self._net(enc_art, entity_out)
        samples = []
        probs = []
        sample, prob, new_greedy = self._net.sample(
            enc_art, entity_out, time_varient=self.time_variant)
        samples.append(sample)
        probs.append(prob)
        greedy = [greedy] + new_greedy
        if len(greedy) != len(prob):
            # debug output before the assertion fires
            print(len(enc_art[0]))
            print(greedy)
            print(new_greedy)
            print(sample)
            print(prob)
        assert len(greedy) == len(prob)
    else:
        greedy = self._net(enc_art, entity_out)
        samples = []
        probs = []
        for i in range(sample_time):
            sample, prob = self._net.sample(enc_art, entity_out, time_varient=False)
            samples.append(sample)
            probs.append(prob)
    return greedy, samples, probs
def __call__(self, raw_article_sents):
    self._net.eval()
    n_art = len(raw_article_sents)
    if self._emb_type == 'W2V':
        articles = conver2id(UNK, self._word2id, raw_article_sents)
    else:
        articles = [self._tokenizer.convert_tokens_to_ids(sentence)
                    for sentence in raw_article_sents]
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    indices = self._net.extract([article], k=min(n_art, self._max_ext))
    return indices
def __call__(self, raw_article_sents):
    self._net.eval()
    n_art = len(raw_article_sents)
    if self._net_type == 'ml_trans_rnn_extractor':
        n_art = len(raw_article_sents[0])
        batch = myextract.get_batch_trans([raw_article_sents])
        indices = self._net.extract(batch, k=min(n_art, self._max_ext))
        return indices, batch
    else:
        articles = conver2id(UNK, self._word2id, raw_article_sents)
        article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
        indices = self._net.extract([article], k=min(n_art, self._max_ext))
        return indices
def __call__(self, raw_article_sents):
    tokenized_sents = raw_article_sents
    stride = 256
    # split the token sequence into overlapping BERT_MAX_LEN windows with a 256-token stride
    tokenized_sents_lists = [tokenized_sents[:BERT_MAX_LEN]]
    length = len(tokenized_sents) - BERT_MAX_LEN
    i = 1
    while length > 0:
        tokenized_sents_lists.append(
            tokenized_sents[i * BERT_MAX_LEN - stride:(i + 1) * BERT_MAX_LEN - stride])
        i += 1
        length -= (BERT_MAX_LEN - stride)
    id_sents = [
        self._tokenizer.convert_tokens_to_ids(tokenized_sents)
        for tokenized_sents in tokenized_sents_lists
    ]
    pad = self._tokenizer.encoder[self._tokenizer._pad_token]
    sources = pad_batch_tensorize(id_sents, pad=pad, cuda=False).to(self._device)
    return sources
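# --- Illustrative sketch (not part of the repo) -------------------------------
# The loop above slices the token sequence into overlapping windows so that
# articles longer than BERT_MAX_LEN can still be encoded. A standalone rerun of
# the same arithmetic, assuming BERT_MAX_LEN = 512 (the usual BERT limit; the
# constant's actual value lives elsewhere in the repo):
BERT_MAX_LEN = 512
stride = 256

tokens = list(range(900))            # stand-in for 900 word-piece tokens
chunks = [tokens[:BERT_MAX_LEN]]     # first window covers [0, 512)
length = len(tokens) - BERT_MAX_LEN
i = 1
while length > 0:
    chunks.append(tokens[i * BERT_MAX_LEN - stride:(i + 1) * BERT_MAX_LEN - stride])
    i += 1
    length -= (BERT_MAX_LEN - stride)

print([(c[0], c[-1] + 1) for c in chunks])   # [(0, 512), (256, 768), (768, 900)]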
def batchify_fn_subgraph_nobert(pad, data, cuda=True, adj_type='edge_as_node',
                                mask_type='none', model_type='gat'):
    """Collate subgraph-decoding examples (word-embedding pipeline, no BERT) into
    padded tensors, masks, adjacency matrices and a feature dictionary."""
    assert adj_type in ['no_edge', 'edge_up', 'edge_down', 'concat_triple', 'edge_as_node']
    (source_lists, targets, source_articles, nodes, sum_worthy, relations,
     triples, node_lists, dec_selection_mask, sent_align_paras,
     segment_feat_sent, segment_feat_para, nodefreq, word_inpara_freq,
     sent_word_inpara_freq) = tuple(map(list, unzip(data)))

    if adj_type == 'edge_as_node':
        batch_adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples, node_lists)))
    else:
        batch_adjs = list(map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    src_nums = [len(source_list) for source_list in source_lists]
    source_articles = pad_batch_tensorize(source_articles, pad=pad, cuda=cuda)
    segment_feat_para = pad_batch_tensorize(segment_feat_para, pad=pad, cuda=cuda)
    sources = list(map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5), source_lists))
    segment_feat_sent = list(map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5), segment_feat_sent))
    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_inpara_freq = pad_batch_tensorize(word_inpara_freq, pad=pad, cuda=cuda)
    sent_word_inpara_freq = list(map(pad_batch_tensorize(pad=pad, cuda=cuda, max_num=5), sent_word_inpara_freq))

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()
    dec_selection_mask = pad_batch_tensorize(dec_selection_mask, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    segment_features = pad_batch_tensorize  # unused alias kept from the original code

    # PAD is -1 (dummy extraction index) for using sequence loss
    target = pad_batch_tensorize(targets, pad=-1, cuda=cuda)
    remove_last = lambda tgt: tgt[:-1]
    tar_in = pad_batch_tensorize(
        list(map(remove_last, targets)),
        pad=0, cuda=cuda  # use 0 here for feeding first conv sentence repr.
    )

    feature_dict = {'seg_para': segment_feat_para,
                    'seg_sent': segment_feat_sent,
                    'sent_inpara_freq': sent_word_inpara_freq,
                    'word_inpara_freq': word_inpara_freq,
                    'node_freq': nodefreq}

    fw_args = (src_nums, tar_in,
               (sources, source_articles, feature_dict),
               (_nodes, nmask, node_num, sum_worthy, dec_selection_mask),
               (_relations, rmask, triples, batch_adjs, node_lists, sent_align_paras))
    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target,)
    return fw_args, loss_args
def batchify_fn_graph_rl(pad, start, end, data, cuda=True,
                         adj_type='concat_triple', docgraph=True,
                         reward_data_dir=None):
    """Collate graph-augmented RL examples; returns decoding fw_args plus the raw
    articles/targets (and optional questions) needed for reward computation."""
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples = tuple(
            map(list, unzip(batch)))
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
    else:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples, node_lists = tuple(
            map(list, unzip(batch)))
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}

    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, adjs,
                   START, END, UNK, 100)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, node_lists, adjs,
                   START, END, UNK, 100)
    loss_args = (raw_articles, ext_id2word, raw_targets, questions)
    return fw_args, loss_args
def batchify_fn_gat_copy_from_graph(pad, start, end, data, cuda=True,
                                    adj_type='concat_triple', mask_type='none',
                                    decoder_supervision=False):
    """Collate graph-attention examples with copy-from-graph alignment tensors."""
    (sources, ext_srcs, tar_ins, targets,
     nodes, nodelengths, sum_worthy, relations, rlengths, triples,
     all_node_words, ext_node_aligns, gold_copy_mask) = tuple(map(list, unzip(data)))

    if adj_type == 'concat_triple':
        adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)]
    elif adj_type == 'edge_as_node':
        adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                for triple, node, relation in zip(triples, nodes, relations)]
    else:
        adjs = [make_adj(triple, len(node), len(node), cuda)
                for triple, node, relation in zip(triples, nodes, relations)]

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]
    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)
    all_node_word = pad_batch_tensorize(all_node_words, pad, cuda)
    all_node_mask = pad_batch_tensorize(all_node_words, pad=-1, cuda=cuda).ne(-1).float()
    ext_node_aligns = pad_batch_tensorize(ext_node_aligns, pad=0, cuda=cuda)
    gold_copy_mask = pad_batch_tensorize(gold_copy_mask, pad=0, cuda=cuda).float()

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
               (_nodes, nmask, node_num, sum_worthy),
               (_relations, rmask, triples, adjs),
               (all_node_word, all_node_mask, ext_node_aligns, gold_copy_mask))
    if 'soft' in mask_type and decoder_supervision:
        raise Exception('not implemented yet')
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise Exception('not implemented yet')
    else:
        loss_args = (target,)
    return fw_args, loss_args
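# --- Illustrative sketch (not part of the repo) -------------------------------
# The batchify functions pad node/relation word-id lists with pad_batch_tensorize_3d
# and then derive masks via .ne(-1): padding with -1 first and testing "not equal
# to -1" yields a 0/1 float mask over real positions. The helper's true definition
# lives elsewhere in the repo; this minimal stand-in only reproduces the
# (batch, max_items, max_words) shape implied by the call sites.
import torch

def pad_batch_tensorize_3d_sketch(batch, pad=0, cuda=False):
    tensor_type = torch.cuda.LongTensor if cuda else torch.LongTensor
    max_items = max(len(items) for items in batch)
    max_words = max((len(words) for items in batch for words in items), default=1)
    out = tensor_type(len(batch), max_items, max_words).fill_(pad)
    for b, items in enumerate(batch):
        for i, words in enumerate(items):
            out[b, i, :len(words)] = tensor_type(words)
    return out

nodes = [[[2, 3], [4]], [[5, 6, 7]]]                      # 2 examples with ragged node lists
_nodes = pad_batch_tensorize_3d_sketch(nodes, pad=0)      # word ids, zero-padded
nmask = pad_batch_tensorize_3d_sketch(nodes, pad=-1).ne(-1).float()
print(_nodes.shape, nmask[0, 1])                          # torch.Size([2, 2, 3]) tensor([1., 0., 0.])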
def batchify_fn_gat(pad, start, end, data, cuda=True,
                    adj_type='concat_triple', mask_type='none',
                    decoder_supervision=False, docgraph=True):
    """Collate graph-attention abstractor examples, for document-level graphs
    (docgraph=True) or paragraph-level subgraphs."""
    (sources, ext_srcs, tar_ins, targets,
     nodes, nodelengths, sum_worthy, word_freq_feat, nodefreq,
     relations, rlengths, triples) = tuple(map(list, unzip(data)))
    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]

    src_lens = [len(src) for src in sources]
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]
    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq, 'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict, node_lists),
                   (_relations, rmask, triples, adjs))
    if 'soft' in mask_type and decoder_supervision:
        raise Exception('not implemented yet')
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise Exception('not implemented yet')
    else:
        loss_args = (target,)
    return fw_args, loss_args
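# --- Illustrative sketch (not part of the repo) -------------------------------
# batchify_fn_gat and its siblings all follow the same collate contract: take a
# list of preprocessed examples and return (fw_args, loss_args). A common way to
# use such a function is to bind the constant arguments (pad/start/end ids,
# adjacency type, ...) with functools.partial and hand it to a DataLoader as
# collate_fn. The repo's real training loop may wire this up differently; the toy
# collate below only demonstrates the pattern with made-up example data.
from functools import partial
from torch.utils.data import DataLoader

def toy_batchify(pad, start, end, data, cuda=False):
    # stand-in with the same signature shape and return contract as batchify_fn_gat
    sources, targets = zip(*data)
    fw_args = ([[start] + s for s in sources],)
    loss_args = ([t + [end] for t in targets],)
    return fw_args, loss_args

dataset = [([3, 4, 5], [6, 7]), ([8], [9, 10, 11])]   # (source_ids, target_ids) pairs
loader = DataLoader(dataset, batch_size=2,
                    collate_fn=partial(toy_batchify, 0, 1, 2, cuda=False))
for fw_args, loss_args in loader:
    print(fw_args, loss_args)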
def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
    """Encode the article and its knowledge graph (document-level or paragraph-level),
    then run greedy extraction and, during training, one sampled extraction."""
    if self._enable_docgraph:
        if self._bert:
            raise NotImplementedError
        else:
            raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, \
                triples, relations, sent_aligns = raw_input
            if self._adj_type == 'concat_triple':
                adjs = [make_adj_triple(triples, len(nodes), len(relations), self._cuda)]
            elif self._adj_type == 'edge_as_node':
                adjs = [make_adj_edge_in(triples, len(nodes), len(relations), self._cuda)]
            else:
                adjs = [make_adj(triples, len(nodes), len(nodes), self._cuda)]
    else:
        if self._bert:
            _, raw_article_sents, nodes, nodefreq, triples, relations, node_lists, word_nums = raw_input
        else:
            raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq, \
                triples, relations, sent_aligns, node_lists = raw_input
        if self._adj_type == 'edge_as_node':
            adjs = [subgraph_make_adj_edge_in((triples, node_lists), cuda=self._cuda)]
        else:
            adjs = [subgraph_make_adj((triples, node_lists), cuda=self._cuda)]

    if not self._bert:
        sent_word_freq = pad_batch_tensorize(sent_word_freq, pad=0, max_num=5, cuda=self._cuda)
        word_freq_feat = pad_batch_tensorize([word_freq_feat], pad=0, cuda=self._cuda)

    nodenum = [len(nodes)]
    sentnum = [len(raw_article_sents)]
    nmask = pad_batch_tensorize(nodes, pad=-1, cuda=self._cuda).ne(-1).float().unsqueeze(0)
    nodes = pad_batch_tensorize(nodes, pad=0, cuda=self._cuda).unsqueeze(0)
    nodefreq = pad_batch_tensorize([nodefreq], pad=0, cuda=self._cuda)

    if self._bert:
        articles = self._batcher(raw_article_sents)
        articles, article_sent = self._encode_bert(articles, [word_nums])
        enc_sent = self._sent_enc(article_sent[0], None, None, None).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)
        sent_aligns = None
    else:
        article_sent, articles = self._batcher(raw_article_sents)
        articles = self._sent_enc._embedding(articles)
        sent_aligns = [sent_aligns]
        if self._pe:
            bs, max_art_len, _ = articles.size()
            src_pos = torch.tensor([[i for i in range(max_art_len)]
                                    for _ in range(bs)]).to(articles.device)
            src_pos = self._sent_enc.poisition_enc(src_pos)
            articles = torch.cat([articles, src_pos], dim=-1)
        if 'inpara_freq' in self._feature_banks:
            word_inpara_freq = self._sent_enc._inpara_embedding(word_freq_feat)
            articles = torch.cat([articles, word_inpara_freq], dim=-1)
        enc_sent = self._sent_enc(article_sent, None, sent_word_freq, None).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)

    if self._enable_docgraph:
        nodes = self._encode_docgraph(articles, nodes, nmask, adjs, nodenum,
                                      enc_out=enc_art, sent_nums=sentnum,
                                      nodefreq=nodefreq)
        # not produced by the document-level graph encoder; set to None here so the
        # later `paras` tuple is defined (the flattened original leaves them unset)
        topics = None
        node_align_paras = None
        topic_length = None
    else:
        outputs = self._encode_paragraph(articles, nodes, nmask, adjs, [node_lists],
                                         enc_out=enc_art, sent_nums=sentnum,
                                         nodefreq=nodefreq)
        if self._hierarchical_attn:
            (topics, topic_length), masks, (nodes, node_length, node_align_paras) = outputs
            node_align_paras = pad_batch_tensorize(node_align_paras, pad=0,
                                                   cuda=False).to(nodes.device)
        elif 'soft' in self._mask_type:
            (nodes, topic_length), masks = outputs
            topics = None
            node_align_paras = None
        else:
            nodes, topic_length = outputs
            topics = None
            node_align_paras = None

    nodes = nodes.squeeze(0)
    if not validate:
        greedy = self._net(enc_art, nodes, aligns=sent_aligns,
                           paras=(topics, node_align_paras, topic_length))
        samples = []
        probs = []
        sample, prob = self._net.sample(enc_art, nodes, aligns=sent_aligns,
                                        paras=(topics, node_align_paras, topic_length))
        samples.append(sample)
        probs.append(prob)
    else:
        greedy = self._net(enc_art, nodes, aligns=sent_aligns,
                           paras=(topics, node_align_paras, topic_length))
        samples = []
        probs = []
        # sampling is skipped during validation
    return greedy, samples, probs
def batchify_fn_gat_bert(tokenizer, data, cuda=True,
                         adj_type='concat_triple', mask_type='none',
                         docgraph=True):
    """BERT-tokenized variant of batchify_fn_gat; start/end/pad ids come from the tokenizer."""
    (sources, ext_srcs, tar_ins, targets,
     nodes, nodelengths, sum_worthy, nodefreq, relations, rlengths, triples,
     src_lens) = (data[0],) + tuple(map(list, unzip(data[1])))
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]

    # src_lens is provided directly in the batch, so it is not recomputed here
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]
    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict, node_lists),
                   (_relations, rmask, triples, adjs))
    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target,)
    return fw_args, loss_args
def __call__(self, raw_article_sents):
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    return article
def forward(self, raw_article_sents, n_abs=None, sample_time=1, validate=False):
    """Encode sentences (optionally through BERT with overlapping windows) and run
    greedy plus sampled extraction."""
    if self._bert:
        if self._bert_sent:
            _, article, word_num = raw_article_sents
            article = pad_batch_tensorize(article, pad=0, cuda=True)
            mask = (article != 0).detach().float()
            with torch.no_grad():
                bert_out = self._bert_model(article)
            # concatenate the last four BERT layers and project them
            bert_hidden = torch.cat([bert_out[-1][_] for _ in [-4, -3, -2, -1]], dim=-1)
            bert_hidden = self._bert_relu(self._bert_linear(bert_hidden))
            bert_hidden = bert_hidden * mask.unsqueeze(2)
            article_sent = bert_hidden
            enc_sent = self._sent_enc(article_sent).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)
        else:
            _, articles, word_num = raw_article_sents
            if self._bert_stride != 0:
                source_num = sum(word_num)
                articles = pad_batch_tensorize(articles, pad=0, cuda=True)
            else:
                articles = torch.tensor(articles, device='cuda')
            with torch.no_grad():
                bert_out = self._bert_model(articles)
            bert_hidden = torch.cat([bert_out[-1][_] for _ in [-4, -3, -2, -1]], dim=-1)
            bert_hidden = self._bert_relu(self._bert_linear(bert_hidden))
            hsz = bert_hidden.size(2)
            if self._bert_stride != 0:
                # stitch the overlapping BERT windows back into one token sequence
                batch_id = 0
                source = torch.zeros(source_num, hsz).to(bert_hidden.device)
                if source_num < BERT_MAX_LEN:
                    source[:source_num, :] += bert_hidden[batch_id, :source_num, :]
                    batch_id += 1
                else:
                    source[:BERT_MAX_LEN, :] += bert_hidden[batch_id, :BERT_MAX_LEN, :]
                    batch_id += 1
                    start = BERT_MAX_LEN
                    while start < source_num:
                        if start - self._bert_stride + BERT_MAX_LEN < source_num:
                            end = start - self._bert_stride + BERT_MAX_LEN
                            batch_end = BERT_MAX_LEN
                        else:
                            end = source_num
                            batch_end = source_num - start + self._bert_stride
                        source[start:end, :] += bert_hidden[batch_id, self._bert_stride:batch_end, :]
                        batch_id += 1
                        start += (BERT_MAX_LEN - self._bert_stride)
                bert_hidden = source.unsqueeze(0)
                del source
            # regroup the flat token sequence into per-sentence tensors, zero-padded
            max_word_num = max(word_num)
            new_word_num = []
            start_num = 0
            for num in word_num:
                new_word_num.append((start_num, start_num + num))
                start_num += num
            article_sent = torch.stack([
                torch.cat([bert_hidden[0, num[0]:num[1], :],
                           torch.zeros(max_word_num - num[1] + num[0],
                                       hsz).to(bert_hidden.device)], dim=0)
                if (num[1] - num[0]) != max_word_num
                else bert_hidden[0, num[0]:num[1], :]
                for num in new_word_num
            ])
            enc_sent = self._sent_enc(article_sent).unsqueeze(0)
            enc_art = self._art_enc(enc_sent)
    else:
        article_sent = self._batcher(raw_article_sents)
        enc_sent = self._sent_enc(article_sent).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)

    if self.time_variant and not validate:
        greedy = self._net(enc_art)
        samples = []
        probs = []
        sample, prob, new_greedy = self._net.sample(enc_art, time_varient=self.time_variant)
        samples.append(sample)
        probs.append(prob)
        greedy = [greedy] + new_greedy
        if len(greedy) != len(prob):
            # debug output before the assertion fires
            print(len(enc_art[0]))
            print(greedy)
            print(new_greedy)
            print(sample)
            print(prob)
        assert len(greedy) == len(prob)
    else:
        greedy = self._net(enc_art)
        samples = []
        probs = []
        for i in range(sample_time):
            sample, prob = self._net.sample(enc_art, time_varient=False)
            samples.append(sample)
            probs.append(prob)
    return greedy, samples, probs
def __call__(self, raw_article_sents):
    articles = conver2id(UNK, self._word2id, raw_article_sents)
    article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)
    return article
def batchify_fn_graph_rl_bert(tokenizer, data, cuda=True,
                              adj_type='concat_triple', docgraph=True,
                              reward_data_dir=None):
    """BERT-tokenized variant of batchify_fn_graph_rl; special-token ids come from the tokenizer."""
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    unk = tokenizer.encoder[tokenizer._unk_token]
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        (sources, ext_srcs, nodes, nodefreq, relations, triples,
         src_lens, tar_ins, targets) = (batch[0],) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
    else:
        (sources, ext_srcs, nodes, nodefreq, relations, triples, node_lists,
         src_lens, tar_ins, targets) = (batch[0],) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda), zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda), zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=0, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)

    # src_lens is provided directly in the batch, so it is not recomputed here
    sources = [src for src in sources]
    ext_srcs = [ext for ext in ext_srcs]
    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}

    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, adjs,
                   start, end, unk, 150, tar_in)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, node_lists, adjs,
                   start, end, unk, 150, tar_in)
    loss_args = (raw_articles, ext_id2word, raw_targets, questions, target)
    return fw_args, loss_args