import torch
from toolz.sandbox.core import unzip  # assumed import; the repo may use cytoolz instead

# The adjacency builders (make_adj, make_adj_triple, make_adj_edge_in,
# subgraph_make_adj, subgraph_make_adj_edge_in), the padding helpers
# (pad_batch_tensorize, pad_batch_tensorize_3d) and the START/END/UNK special
# ids are assumed to be defined or imported elsewhere in this module.


def batchify_fn_graph_rl(pad, start, end, data, cuda=True,
                         adj_type='concat_triple', docgraph=True,
                         reward_data_dir=None):
    """Collate an RL batch: build graph adjacencies, pad sources and nodes,
    and return (fw_args, loss_args) for the forward pass and reward loss."""
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples = \
            tuple(map(list, unzip(batch)))
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
    else:
        sources, ext_srcs, nodes, word_freq_feat, nodefreq, relations, triples, node_lists = \
            tuple(map(list, unzip(batch)))
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda),
                            zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda),
                            zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq,
                    'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    # padding with -1 and taking .ne(-1) yields a {0,1} mask over real node words
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1  # unused here; extend_vsize is passed instead
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    # NOTE: the decode specials come from the module-level START/END/UNK ids;
    # the start/end parameters are unused in this variant. The trailing 100 is
    # presumably the maximum decoding length.
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, adjs,
                   START, END, UNK, 100)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, node_lists, adjs,
                   START, END, UNK, 100)
    loss_args = (raw_articles, ext_id2word, raw_targets, questions)
    return fw_args, loss_args
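# Hypothetical usage sketch (not part of the original module): batchify
# functions like the one above are typically curried into a single-argument
# collate function for a data loader. The ids below are illustrative, not the
# repo's actual constants, and `loader` is assumed.
from functools import partial

def _example_collate(cuda=True):
    PAD, START_ID, END_ID = 0, 2, 3  # illustrative special-token ids
    return partial(batchify_fn_graph_rl, PAD, START_ID, END_ID,
                   cuda=cuda, adj_type='edge_as_node', docgraph=True)
# fw_args, loss_args = _example_collate()(next(iter(loader)))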
def batchify_fn_gat(pad, start, end, data, cuda=True,
                    adj_type='concat_triple', mask_type='none',
                    decoder_supervision=False, docgraph=True):
    """Collate an MLE batch for the graph-attention summarizer and return
    (fw_args, loss_args)."""
    (sources, ext_srcs, tar_ins, targets,
     nodes, nodelengths, sum_worthy, word_freq_feat,
     nodefreq, relations, rlengths, triples) = tuple(map(list, unzip(data)))
    if not docgraph:
        # in subgraph mode the per-paragraph node lists travel in the
        # `nodelengths` slot
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda),
                            zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda),
                            zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]

    src_lens = [len(src) for src in sources]
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    # teacher forcing: decoder input gets <start> prepended, the training
    # target gets <end> appended
    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    word_freq = pad_batch_tensorize(word_freq_feat, pad=pad, cuda=cuda)
    feature_dict = {'word_inpara_freq': word_freq,
                    'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict, node_lists),
                   (_relations, rmask, triples, adjs))
    if 'soft' in mask_type and decoder_supervision:
        raise NotImplementedError('soft mask with decoder supervision')
    elif 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    elif decoder_supervision:
        raise NotImplementedError('decoder supervision')
    else:
        loss_args = (target,)
    return fw_args, loss_args
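# A tiny worked example of the teacher-forcing shift performed above (ids are
# illustrative): the decoder consumes the gold sequence prepended with <start>
# and is trained to predict the same sequence appended with <end>.
def _shift_example():
    start, end, tgt = 2, 3, [7, 8, 9]
    tar_in = [start] + tgt    # [2, 7, 8, 9] -> fed to the decoder
    target = tgt + [end]      # [7, 8, 9, 3] -> compared with the decoder logits
    return tar_in, target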
def batchify_fn_gat_bert(tokenizer, data, cuda=True, adj_type='concat_triple',
                         mask_type='none', docgraph=True):
    """BERT-tokenizer variant of batchify_fn_gat: special-token ids come from
    the tokenizer and src_lens arrives precomputed with the batch."""
    (sources, ext_srcs, tar_ins, targets,
     nodes, nodelengths, sum_worthy, nodefreq,
     relations, rlengths, triples, src_lens) = \
        (data[0],) + tuple(map(list, unzip(data[1])))
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    if not docgraph:
        node_lists = nodelengths
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda),
                            zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda),
                            zip(triples, node_lists)))
    else:
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]

    # src_lens is provided with the batch, so it is not recomputed here
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]

    nodefreq = pad_batch_tensorize(nodefreq, pad=pad, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}

    source = pad_batch_tensorize(sources, pad, cuda)
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    sum_worthy_label = pad_batch_tensorize(sum_worthy, pad=-1, cuda=cuda)
    sum_worthy = pad_batch_tensorize(sum_worthy, pad=0, cuda=cuda).float()

    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    _relations = pad_batch_tensorize_3d(relations, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()
    rmask = pad_batch_tensorize_3d(relations, pad=-1, cuda=cuda).ne(-1).float()

    ext_vsize = ext_src.max().item() + 1
    if docgraph:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict),
                   (_relations, rmask, triples, adjs))
    else:
        fw_args = (source, src_lens, tar_in, ext_src, ext_vsize,
                   (_nodes, nmask, node_num, sum_worthy, feature_dict, node_lists),
                   (_relations, rmask, triples, adjs))
    if 'soft' in mask_type:
        loss_args = (target, sum_worthy_label)
    else:
        loss_args = (target,)
    return fw_args, loss_args
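# A minimal sketch of the 3-D padding helper these functions rely on, assuming
# `pad_batch_tensorize_3d` right-pads a batch of node lists (each node a list
# of word ids) into a (batch, max_node, max_word) LongTensor; the real helper
# lives elsewhere in the repo. Padding the same ragged structure with pad=-1
# and taking .ne(-1) is how the node/relation masks (`nmask`, `rmask`) above
# are derived.
def _pad_batch_tensorize_3d_sketch(inputs, pad, cuda=True):
    batch = len(inputs)
    max_num = max(len(seqs) for seqs in inputs)
    max_len = max((len(seq) for seqs in inputs for seq in seqs), default=1)
    tensor = torch.full((batch, max_num, max_len), pad, dtype=torch.long)
    for i, seqs in enumerate(inputs):
        for j, seq in enumerate(seqs):
            tensor[i, j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return tensor.cuda() if cuda else tensor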
def forward(self, raw_input, n_abs=None, sample_time=1, validate=False):
    """Encode the article and its graph, then return a greedy extraction plus
    (during training) sampled rollouts and their probabilities for RL."""
    if self._enable_docgraph:
        if self._bert:
            raise NotImplementedError
        else:
            (raw_article_sents, nodes, nodefreq, word_freq_feat,
             sent_word_freq, triples, relations, sent_aligns) = raw_input
        if self._adj_type == 'concat_triple':
            adjs = [make_adj_triple(triples, len(nodes), len(relations), self._cuda)]
        elif self._adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triples, len(nodes), len(relations), self._cuda)]
        else:
            adjs = [make_adj(triples, len(nodes), len(nodes), self._cuda)]
    else:
        if self._bert:
            (_, raw_article_sents, nodes, nodefreq, triples,
             relations, node_lists, word_nums) = raw_input
        else:
            (raw_article_sents, nodes, nodefreq, word_freq_feat, sent_word_freq,
             triples, relations, sent_aligns, node_lists) = raw_input
        if self._adj_type == 'edge_as_node':
            adjs = [subgraph_make_adj_edge_in((triples, node_lists), cuda=self._cuda)]
        else:
            adjs = [subgraph_make_adj((triples, node_lists), cuda=self._cuda)]

    if not self._bert:
        sent_word_freq = pad_batch_tensorize(sent_word_freq, pad=0, max_num=5,
                                             cuda=self._cuda)
        word_freq_feat = pad_batch_tensorize([word_freq_feat], pad=0, cuda=self._cuda)

    nodenum = [len(nodes)]
    sentnum = [len(raw_article_sents)]
    # build the node mask from the raw lists before tensorizing them
    nmask = pad_batch_tensorize(nodes, pad=-1,
                                cuda=self._cuda).ne(-1).float().unsqueeze(0)
    nodes = pad_batch_tensorize(nodes, pad=0, cuda=self._cuda).unsqueeze(0)
    nodefreq = pad_batch_tensorize([nodefreq], pad=0, cuda=self._cuda)

    if self._bert:
        articles = self._batcher(raw_article_sents)
        articles, article_sent = self._encode_bert(articles, [word_nums])
        enc_sent = self._sent_enc(article_sent[0], None, None, None).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)
        sent_aligns = None
    else:
        article_sent, articles = self._batcher(raw_article_sents)
        articles = self._sent_enc._embedding(articles)
        sent_aligns = [sent_aligns]
        if self._pe:
            bs, max_art_len, _ = articles.size()
            src_pos = torch.tensor([list(range(max_art_len))
                                    for _ in range(bs)]).to(articles.device)
            src_pos = self._sent_enc.poisition_enc(src_pos)  # (sic: encoder attribute name)
            articles = torch.cat([articles, src_pos], dim=-1)
        if 'inpara_freq' in self._feature_banks:
            word_inpara_freq = self._sent_enc._inpara_embedding(word_freq_feat)
            articles = torch.cat([articles, word_inpara_freq], dim=-1)
        enc_sent = self._sent_enc(article_sent, None, sent_word_freq, None).unsqueeze(0)
        enc_art = self._art_enc(enc_sent)

    if self._enable_docgraph:
        nodes = self._encode_docgraph(articles, nodes, nmask, adjs, nodenum,
                                      enc_out=enc_art, sent_nums=sentnum,
                                      nodefreq=nodefreq)
        # the doc-graph path has no hierarchical topics; default them so the
        # decoder calls below do not raise a NameError
        topics = None
        node_align_paras = None
        topic_length = None
    else:
        outputs = self._encode_paragraph(articles, nodes, nmask, adjs,
                                         [node_lists], enc_out=enc_art,
                                         sent_nums=sentnum, nodefreq=nodefreq)
        if self._hierarchical_attn:
            (topics, topic_length), masks, \
                (nodes, node_length, node_align_paras) = outputs
            node_align_paras = pad_batch_tensorize(node_align_paras, pad=0,
                                                   cuda=False).to(nodes.device)
        elif 'soft' in self._mask_type:
            (nodes, topic_length), masks = outputs
            topics = None
            node_align_paras = None
        else:
            nodes, topic_length = outputs
            topics = None
            node_align_paras = None
    nodes = nodes.squeeze(0)

    if not validate:
        greedy = self._net(enc_art, nodes, aligns=sent_aligns,
                           paras=(topics, node_align_paras, topic_length))
        samples = []
        probs = []
        # originally sampled exactly once; honor the sample_time argument
        for _ in range(sample_time):
            sample, prob = self._net.sample(enc_art, nodes, aligns=sent_aligns,
                                            paras=(topics, node_align_paras,
                                                   topic_length))
            samples.append(sample)
            probs.append(prob)
    else:
        # validation only needs the greedy rollout
        greedy = self._net(enc_art, nodes, aligns=sent_aligns,
                           paras=(topics, node_align_paras, topic_length))
        samples = []
        probs = []
    return greedy, samples, probs
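# Hedged sketch (not the repo's training loop): how the (greedy, samples,
# probs) triple returned by forward() is typically consumed in self-critical
# RL, with the greedy rollout as the reward baseline. `reward_fn` is a
# hypothetical stand-in (e.g. ROUGE against the reference), and each element
# of `probs` is assumed to be the log-probability tensor of its rollout.
def _self_critical_loss_sketch(greedy, samples, probs, reward_fn, reference):
    baseline = reward_fn(greedy, reference)
    losses = []
    for sample, log_prob in zip(samples, probs):
        advantage = reward_fn(sample, reference) - baseline
        losses.append(-advantage * log_prob.sum())
    return sum(losses) / max(len(losses), 1)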
def batchify_fn_graph_rl_bert(tokenizer, data, cuda=True,
                              adj_type='concat_triple', docgraph=True,
                              reward_data_dir=None):
    """BERT-tokenizer variant of batchify_fn_graph_rl: special-token ids come
    from the tokenizer and the decoder inputs/targets are batched alongside."""
    start = tokenizer.encoder[tokenizer._bos_token]
    end = tokenizer.encoder[tokenizer._eos_token]
    pad = tokenizer.encoder[tokenizer._pad_token]
    unk = tokenizer.encoder[tokenizer._unk_token]
    if reward_data_dir is not None:
        batch, ext_word2id, raw_articles, raw_targets, questions = data
    else:
        batch, ext_word2id, raw_articles, raw_targets = data
        questions = []
    if docgraph:
        (sources, ext_srcs, nodes, nodefreq, relations,
         triples, src_lens, tar_ins, targets) = \
            (batch[0],) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'concat_triple':
            adjs = [make_adj_triple(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        elif adj_type == 'edge_as_node':
            adjs = [make_adj_edge_in(triple, len(node), len(relation), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
        else:
            adjs = [make_adj(triple, len(node), len(node), cuda)
                    for triple, node, relation in zip(triples, nodes, relations)]
    else:
        (sources, ext_srcs, nodes, nodefreq, relations,
         triples, node_lists, src_lens, tar_ins, targets) = \
            (batch[0],) + tuple(map(list, unzip(batch[1])))
        if adj_type == 'edge_as_node':
            adjs = list(map(subgraph_make_adj_edge_in(cuda=cuda),
                            zip(triples, node_lists)))
        else:
            adjs = list(map(subgraph_make_adj(cuda=cuda),
                            zip(triples, node_lists)))

    nodefreq = pad_batch_tensorize(nodefreq, pad=0, cuda=cuda)
    feature_dict = {'node_freq': nodefreq}
    node_num = [len(_node) for _node in nodes]
    _nodes = pad_batch_tensorize_3d(nodes, pad=0, cuda=cuda)
    nmask = pad_batch_tensorize_3d(nodes, pad=-1, cuda=cuda).ne(-1).float()

    tar_ins = [[start] + tgt for tgt in tar_ins]
    targets = [tgt + [end] for tgt in targets]
    tar_in = pad_batch_tensorize(tar_ins, pad, cuda)
    target = pad_batch_tensorize(targets, pad, cuda)

    # src_lens is provided with the batch, so it is not recomputed here
    sources = list(sources)
    ext_srcs = list(ext_srcs)

    source = pad_batch_tensorize(sources, pad, cuda)
    ext_src = pad_batch_tensorize(ext_srcs, pad, cuda)

    ext_vsize = ext_src.max().item() + 1  # unused here; extend_vsize is passed instead
    extend_vsize = len(ext_word2id)
    ext_id2word = {_id: _word for _word, _id in ext_word2id.items()}
    # the trailing 150 is presumably the maximum decoding length
    if docgraph:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, adjs,
                   start, end, unk, 150, tar_in)
    else:
        fw_args = (source, src_lens, ext_src, extend_vsize,
                   _nodes, nmask, node_num, feature_dict, node_lists, adjs,
                   start, end, unk, 150, tar_in)
    loss_args = (raw_articles, ext_id2word, raw_targets, questions, target)
    return fw_args, loss_args
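# Hypothetical end-to-end usage (the names `tok`, `data`, `net`, `criterion`
# are assumptions): fw_args feeds the network's forward pass and loss_args
# feeds the loss/reward computation, mirroring how the functions above
# package them.
#
#   fw_args, loss_args = batchify_fn_graph_rl_bert(
#       tok, data, cuda=True, adj_type='edge_as_node', docgraph=False)
#   net_out = net(*fw_args)
#   loss = criterion(net_out, *loss_args)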