def forward(self, _, heads, deps):
    """heads.data: MyPackedSequence, amr_l x rel_dim
    deps.data: MyDoublePackedSequence, amr_l x amr_l x rel_dim
    """
    heads_data = heads.data
    deps_data = deps.data
    head_bilinear_transformed = self.bilinear(heads_data)  # all_data x (n_rel x input_size)
    head_bias_unpacked = myunpack(self.head_bias(heads_data), heads.lengths)  # [len x n_rel]
    size = deps_data.size()
    dep_bias = self.dep_bias(deps_data.view(-1, size[-1])).view(size[0], size[1], -1)
    dep_bias_unpacked, length_pairs = mydoubleunpack(
        MyDoublePackedSequence(MyPackedSequence(dep_bias, deps[0][1]), deps[1], dep_bias))  # [len x n_rel]
    bilinear_unpacked = myunpack(head_bilinear_transformed, heads.lengths)
    deps_unpacked, length_pairs = mydoubleunpack(deps)
    output, l = self.bilinearForParallel(
        zip(bilinear_unpacked, deps_unpacked, head_bias_unpacked, dep_bias_unpacked),
        length_pairs)
    myscore_packed = mypack(output, l)
    # prob_packed = MyPackedSequence(myscore_packed.data, l)
    return myscore_packed
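# --- Hedged sketch (not part of the model above) -----------------------------
# A minimal illustration of the biaffine-style scoring that forward() builds,
# on plain tensors for a single sentence instead of MyPackedSequence. It
# assumes self.bilinear maps rel_dim -> n_rel * rel_dim so each transformed
# head vector can be matched against every dependent vector; the dep bias is
# omitted, and all names/sizes here are illustrative assumptions.
def _sketch_biaffine_scores():
    import torch
    n, rel_dim, n_rel = 5, 8, 3              # 5 concepts, toy dimensions
    heads = torch.randn(n, rel_dim)            # head_packed.data for one sentence
    deps = torch.randn(n, n, rel_dim)          # dep_packed.data for one sentence
    W = torch.randn(rel_dim, n_rel * rel_dim)  # stand-in for self.bilinear.weight
    head_bias = torch.randn(rel_dim, n_rel)    # stand-in for self.head_bias
    # bilinear term: transform heads, then contract against every dependent
    ht = (heads @ W).view(n, n_rel, rel_dim)
    bilinear = torch.einsum("hrk,hdk->hdr", ht, deps)   # n x n x n_rel
    scores = bilinear + (heads @ head_bias).unsqueeze(1)  # broadcast head bias
    return scores  # one score per (head, dependent, relation) triple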
def root_score(self, mypackedhead):
    heads = myunpack(*mypackedhead)
    output = []
    for head in heads:
        score = self.root(head).squeeze(1)
        output.append(self.LogSoftmax(score))
    return output
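# --- Hedged sketch (not part of the model above) -----------------------------
# root_score picks one root per sentence by scoring every concept and
# normalizing with a log-softmax over the concept axis. A toy stand-alone
# version, assuming self.root is a Linear(dim, 1):
def _sketch_root_score():
    import torch
    head = torch.randn(6, 16)          # 6 concepts, toy encoder dim
    root = torch.nn.Linear(16, 1)      # stand-in for self.root
    score = root(head).squeeze(1)      # one logit per concept
    return torch.nn.functional.log_softmax(score, dim=0)  # log P(root = i)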
def forward(self, input, index, src_enc):
    """
    input: relBatch, packed_gold_amr_lengths x n_feature (AMR_CAT, AMR_LE, AMR_NER,
        AMR_SENSE, index of nodes); a MyPackedSequence of
        [packed_batch_gold_amr_len x tgt_feature_dim]
    index: rel_index_batch, list(batch, real_gold_amr_len); its content is the index
        into the recategorized amr nodes, i.e. a mapping
    src_enc: DoublePackedSequence(packed: packed_g_amr_len x re_amr_len x dim,
        lengths=re_lens)
    """
    assert isinstance(input, MyPackedSequence), input
    # lengths: real_gold_amr_lens
    # after unpacking, input is packed_gold_amr_lengths x n_features
    input, lengths = input
    if self.alpha_dropout and self.training:
        input = data_dropout(input, self.alpha_dropout)
    psd_target_pos_embed = self.psd_target_pos_lut(input[:, PSD_POS])
    # psd_sense_embed = self.psd_sense_lut(input[:, PSD_SENSE])
    # psd_lemma_embed = self.lemma_lut(input[:, PSD_LE])
    # psd_emb = torch.cat([psd_target_pos_embed, psd_sense_embed, psd_lemma_embed], 1)
    # psd_emb = torch.cat([psd_target_pos_embed, psd_sense_embed], 1)
    psd_emb = torch.cat([psd_target_pos_embed], 1)

    # head_emb_t: MyPackedSequence(data: packed_real_g_amr_l x dim), g_amr_l
    # dep_emb_t: MyDoublePackedSequence(MyPackedSequence(packed_real_g_amr_l x real_g_amr_l x dim), length_pairs)
    # length_pairs: (g_amr_l, g_amr_l)
    head_emb_t, dep_emb_t, length_pairs = self.getEmb(index, src_enc)  # packed, mydoublepacked
    head_emb = torch.cat([psd_emb, head_emb_t.data], 1)

    dep_psd_emb_t = myunpack(*MyPackedSequence(psd_emb, lengths))
    dep_psd_emb = [emb.unsqueeze(0).expand(emb.size(0), emb.size(0), emb.size(-1))
                   for emb in dep_psd_emb_t]
    mydouble_psd_emb = mydoublepack(dep_psd_emb, length_pairs)
    dep_emb = torch.cat([mydouble_psd_emb.data, dep_emb_t.data], -1)

    assert head_emb.size(-1) == self.inputSize, "wrong head size {}".format(head_emb.size())
    # head_packed: MyPackedSequence(data: packed_real_g_amr_l x rel_dim), g_amr_l
    head_packed = MyPackedSequence(self.head(head_emb), lengths)  # total x rel_dim
    head_psd_packed = MyPackedSequence(psd_emb, lengths)  # total x rel_dim

    size = dep_emb.size()
    assert dep_emb.size(-1) == self.inputSize, "wrong dep size {}".format(dep_emb.size())
    dep = self.dep(dep_emb.view(-1, size[-1])).view(size[0], size[1], -1)
    # dep_packed: MyDoublePackedSequence(MyPackedSequence(packed_real_g_amr_l x real_g_amr_l x rel_dim), length_pairs)
    dep_packed = MyDoublePackedSequence(
        MyPackedSequence(dep, mydouble_psd_emb[0][1]), mydouble_psd_emb[1], dep)
    return head_psd_packed, head_packed, dep_packed
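# --- Hedged sketch (not part of the model above) -----------------------------
# The unsqueeze(0).expand(...) trick used in forward() broadcasts a per-node
# embedding matrix (n x d) to an (n x n x d) grid, so that each candidate
# (head, dependent) pair can later be concatenated along the last axis.
# expand() only creates a view; nothing is copied until torch.cat runs.
def _sketch_pairwise_expand():
    import torch
    n, d = 4, 5
    emb = torch.randn(n, d)
    grid = emb.unsqueeze(0).expand(n, n, d)  # grid[i, j] == emb[j] for every i
    assert torch.equal(grid[0], emb) and torch.equal(grid[1], emb)
    return grid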
def posteriorIndictedEmb(embs, posterior):
    """
    embs: BERT encodings, a packed sequence over padded_batch_src_len x bert_output_dim
    posterior: mypack(packed_gold_amr_len x src_len), gold_amr_l
    """
    # after unpacking, embs: [max_src_len x batch x bert_output_dim]
    embs, src_len = unpack(embs)
    if isinstance(posterior, MyPackedSequence):
        # after unpacking, posterior: list(batch_size, real_gold_amr_len x src_len)
        posterior = myunpack(*posterior)
        # after transposing, embs: batch_size x padded_src_len x dim
        embs = embs.transpose(0, 1)
        out = []
        lengths = []
        # real gold amr length for each example in the batch
        amr_len = [len(p) for p in posterior]
        for i, emb in enumerate(embs):
            # expanded_emb: real_gold_amr_len x src_len x dim
            expanded_emb = emb.unsqueeze(0).expand([amr_len[i]] + [s for s in emb.size()])
            indicator = posterior[i].unsqueeze(2)  # real_gold_amr_len x src_len x 1
            out.append(torch.cat([expanded_emb, indicator], 2))  # real_gold_amr_len x src_len x (dim+1)
            # lengths = list(batch * real_gold_amr_len); every entry stores src_len
            lengths = lengths + [src_len[i]] * amr_len[i]
        data = torch.cat(out, dim=0)  # packed_real_gold_amr_len x src_len x (dim+1)
        return pack(data, lengths, batch_first=True), amr_len
    elif isinstance(posterior, list):  # real alignments
        embs = embs.transpose(0, 1)
        src_l = embs.size(1)
        amr_len = [len(p) for p in posterior]
        out = []
        lengths = []
        for i, emb in enumerate(embs):
            amr_l = len(posterior[i])
            # expanded_emb: amr_len x src_len x dim
            expanded_emb = emb.unsqueeze(0).expand([amr_l] + [s for s in emb.size()])
            indicator = emb.new_zeros((amr_l, src_l))
            # scatter a 1.0 at each aligned source position: amr_len x src_len
            scattered_indicator = indicator.scatter(1, posterior[i].unsqueeze(1), 1.0)
            out.append(torch.cat([expanded_emb, scattered_indicator.unsqueeze(2)], 2))  # amr_len x src_len x (dim+1)
            lengths = lengths + [src_len[i]] * amr_l
        data = torch.cat(out, dim=0)  # packed_amr_len x src_len x (dim+1)
        return pack(data, lengths, batch_first=True), amr_len
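# --- Hedged sketch (not part of the function above) --------------------------
# posteriorIndictedEmb tiles the source encodings once per AMR node and adds a
# (dim+1)-th "indicator" channel marking which source tokens the node aligns
# to: soft posterior weights in the relaxed case, or a scattered one-hot row
# when a hard alignment list is given. Toy tensors replace packed sequences;
# all sizes are illustrative assumptions.
def _sketch_indicator_channel():
    import torch
    amr_l, src_l, dim = 3, 7, 4
    emb = torch.randn(src_l, dim)                        # one sentence's encodings
    expanded = emb.unsqueeze(0).expand(amr_l, src_l, dim)
    # soft branch: posterior weights over source positions
    soft = torch.softmax(torch.randn(amr_l, src_l), dim=1).unsqueeze(2)
    soft_out = torch.cat([expanded, soft], 2)            # amr_l x src_l x (dim+1)
    # hard branch: one aligned token index per node, scattered to one-hot
    align = torch.tensor([0, 3, 6])
    hard = emb.new_zeros(amr_l, src_l).scatter(1, align.unsqueeze(1), 1.0)
    hard_out = torch.cat([expanded, hard.unsqueeze(2)], 2)
    return soft_out, hard_out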
def getEmb(self, indexes, src_enc):
    head_emb, lengths = [], []
    src_enc = myunpack(*src_enc)  # pre_amr_l/src_l x batch x dim
    for i, index in enumerate(indexes):
        enc = src_enc[i]  # src_l x dim
        head_emb.append(enc[index])  # amr_l x dim
        lengths.append(len(index))
    return mypack(head_emb, lengths)
def getEmb(self, indexes, src_enc):
    """
    indexes: rel_index_batch, list(batch, real_gold_amr_len); its content is the index
        into the recategorized amr nodes, i.e. a mapping
    src_enc: weighted root_src_enc, MyPackedSequence(data: packed_re_amr_len x txt_enc_size,
        lengths: re_amr_lens)
    """
    head_emb, lengths = [], []
    src_enc = myunpack(*src_enc)  # list(batch, real_re_amr_len x src_enc_size)
    for i, index in enumerate(indexes):  # indexes: batch_size lists of real_gold_amr_len
        enc = src_enc[i]  # real_re_amr_len x src_enc_size
        # the content of index is a real_re_amr index: list(batch, real_gold_amr_len, dim)
        head_emb.append(enc[index])
        lengths.append(len(index))  # list(batch, real_gold_amr_len)
    # MyPackedSequence(data: packed_gold_amr_len x dim, lengths: real_gold_amr_lens)
    return mypack(head_emb, lengths)
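# --- Hedged sketch (not part of the function above) --------------------------
# The core of getEmb is a gather: indexing an (m x dim) encoding matrix with a
# LongTensor of node indices selects one row per gold node, which is how gold
# AMR nodes are mapped onto recategorized-node encodings. Rows may repeat when
# several gold nodes map to the same recategorized node.
def _sketch_index_select():
    import torch
    enc = torch.randn(5, 8)                   # recategorized-node encodings
    index = torch.tensor([0, 2, 2, 4])        # gold node -> recategorized node
    picked = enc[index]                       # 4 x 8
    assert torch.equal(picked[1], picked[2])  # both gold nodes share node 2
    return picked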
def posteriorIndictedEmb(self, embs, posterior):
    # real alignment is sent in as a list of indices;
    # variational relaxed posterior is sent in as a MyPackedSequence
    # out: (batch x amr_len) x src_len x (dim+1)
    embs, src_len = unpack(embs)
    if isinstance(posterior, MyPackedSequence):
        posterior = myunpack(*posterior)
        embs = embs.transpose(0, 1)
        out = []
        lengths = []
        amr_len = [len(p) for p in posterior]
        for i, emb in enumerate(embs):
            expanded_emb = emb.unsqueeze(0).expand([amr_len[i]] + [s for s in emb.size()])  # amr_len x src_len x dim
            indicator = posterior[i].unsqueeze(2)  # amr_len x src_len x 1
            out.append(torch.cat([expanded_emb, indicator], 2))  # amr_len x src_len x (dim+1)
            lengths = lengths + [src_len[i]] * amr_len[i]
        data = torch.cat(out, dim=0)
        return pack(data, lengths, batch_first=True), amr_len
    elif isinstance(posterior, list):
        embs = embs.transpose(0, 1)
        src_l = embs.size(1)
        amr_len = [len(p) for p in posterior]
        out = []
        lengths = []
        for i, emb in enumerate(embs):
            amr_l = len(posterior[i])
            expanded_emb = emb.unsqueeze(0).expand([amr_l] + [s for s in emb.size()])  # amr_l x src_len x dim
            indicator = emb.data.new(amr_l, src_l).zero_()
            indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0)  # amr_l x src_len
            indicator = Variable(indicator.unsqueeze(2))  # amr_l x src_len x 1
            out.append(torch.cat([expanded_emb, indicator], 2))  # amr_l x src_len x (dim+1)
            lengths = lengths + [src_len[i]] * amr_l
        data = torch.cat(out, dim=0)
        return pack(data, lengths, batch_first=True), amr_len
def forward(self, input, index, src_enc):
    assert isinstance(input, MyPackedSequence), input
    input, lengths = input
    if self.alpha and self.training:
        input = data_dropout(input, self.alpha)
    cat_embed = self.cat_lut(input[:, AMR_CAT])
    lemma_embed = self.lemma_lut(input[:, AMR_LE])
    amr_emb = torch.cat([cat_embed, lemma_embed], 1)

    head_emb_t, dep_emb_t, length_pairs = self.getEmb(index, src_enc)  # packed, mydoublepacked
    head_emb = torch.cat([amr_emb, head_emb_t.data], 1)

    dep_amr_emb_t = myunpack(*MyPackedSequence(amr_emb, lengths))
    dep_amr_emb = [emb.unsqueeze(0).expand(emb.size(0), emb.size(0), emb.size(-1))
                   for emb in dep_amr_emb_t]
    mydouble_amr_emb = mydoublepack(dep_amr_emb, length_pairs)
    dep_emb = torch.cat([mydouble_amr_emb.data, dep_emb_t.data], -1)

    head_packed = MyPackedSequence(self.head(head_emb), lengths)  # total x rel_dim
    head_amr_packed = MyPackedSequence(amr_emb, lengths)  # total x rel_dim

    size = dep_emb.size()
    dep = self.dep(dep_emb.view(-1, size[-1])).view(size[0], size[1], -1)
    dep_packed = MyDoublePackedSequence(
        MyPackedSequence(dep, mydouble_amr_emb[0][1]), mydouble_amr_emb[1], dep)
    return head_amr_packed, head_packed, dep_packed
def relProbAndConToGraph(self, srl_batch, srl_prob, roots, appended,
                         get_sense=False, set_wiki=False):
    # srl_batch: batch of concept ids
    # srl_prob: len x batch x n_feature relation scores
    # out: batch of AMRUniversal graphs
    rel_dict = self.dicts["rel_dict"]

    def get_uni_var(concepts, id):
        assert id < len(concepts), (id, concepts)
        uni = concepts[id]
        if uni.le in ["i", "it", "you", "they", "he", "she"] and uni.cat == Rule_Concept:
            return uni, Var(uni.le)
        le = uni.le
        if uni.cat != Rule_String:
            uni.le = uni.le.strip("/").strip(":")
        if ":" in uni.le or "/" in uni.le:
            uni.cat = Rule_String
        if uni.le == "":
            return uni, Var(le + str(id))
        return uni, Var(uni.le[0] + str(id))

    def create_connected_graph(role_scores, concepts, root_id, dependent, aligns):
        # role_scores: amr x amr x rel
        graph = nx.DiGraph()
        n = len(concepts)
        role_scores = role_scores.view(n, n, -1).data
        max_non_score, max_non_score_id = role_scores[:, :, 1:].max(-1)
        max_non_score_id = max_non_score_id + 1
        non_score = role_scores[:, :, 0]
        active_cost = non_score - max_non_score  # lowest-cost edge becomes active first
        candidates = []
        h_vs = []
        for h_id in range(n):
            h, h_v = get_uni_var(concepts, h_id)
            h_vs.append(h_v)
            graph.add_node(h_v, value=h, align=aligns[h_id], gold=True, dep=dependent[h_id])

        constant_links = {}
        normal_edged_links = {}
        for h_id in range(n):
            for d_id in range(n):
                if h_id == d_id:
                    continue
                r = rel_dict.getLabel(max_non_score_id[h_id, d_id])
                r_inver = r + "-of" if not r.endswith("-of") else r[:-3]
                h, h_v = get_uni_var(concepts, h_id)
                d, d_v = get_uni_var(concepts, d_id)
                if concepts[h_id].is_constant() or concepts[d_id].is_constant():
                    if concepts[h_id].is_constant() and concepts[d_id].is_constant():
                        continue
                    elif concepts[h_id].is_constant():
                        constant_links.setdefault(h_v, []).append(
                            (active_cost[h_id, d_id], d_v, r, r_inver))
                    else:
                        constant_links.setdefault(d_v, []).append(
                            (active_cost[h_id, d_id], h_v, r_inver, r))
                elif active_cost[h_id, d_id] < 0:
                    if r in [":name-of", ":ARG0"] and concepts[d_id].le in ["person"]:
                        graph.add_edge(h_v, d_v, role=r)
                        graph.add_edge(d_v, h_v, role=r_inver)
                    else:
                        # if concepts[h_id].le == "name" and r != ":name-of":
                        #     r = ":name-of"
                        normal_edged_links.setdefault((h_v, r), []).append(
                            (active_cost[h_id, d_id], d_v, r_inver))
                else:
                    candidates.append((active_cost[h_id, d_id], (h_v, d_v, r, r_inver)))

        max_edge_per_node = 1 if not self.training else 100
        for h_v, r in normal_edged_links:
            sorted_list = sorted(normal_edged_links[h_v, r], key=lambda j: j[0])
            for _, d_v, r_inver in sorted_list[:max_edge_per_node]:
                graph.add_edge(h_v, d_v, role=r)
                graph.add_edge(d_v, h_v, role=r_inver)
            for cost, d_v, r_inver in sorted_list[max_edge_per_node:]:  # remaining
                candidates.append((cost, (h_v, d_v, r, r_inver)))

        for h_v in constant_links:
            _, d_v, r, r_inver = sorted(constant_links[h_v], key=lambda j: j[0])[0]
            graph.add_edge(h_v, d_v, role=r)
            graph.add_edge(d_v, h_v, role=r_inver)

        candidates = sorted(candidates, key=lambda j: j[0])
        for _, (h_v, d_v, r, r_inver) in candidates:
            if nx.is_strongly_connected(graph):
                break
            if not nx.has_path(graph, h_v, d_v):
                graph.add_edge(h_v, d_v, role=r, force_connect=True)
                graph.add_edge(d_v, h_v, role=r_inver, force_connect=True)

        _, root_v = get_uni_var(concepts, root_id)
        h_v = BOS_WORD
        root_symbol = AMRUniversal(BOS_WORD, BOS_WORD, NULL_WORD)
        graph.add_node(h_v, value=root_symbol, align=-1, gold=True, dep=1)
        graph.add_edge(h_v, root_v, role=":top")
        graph.add_edge(root_v, h_v, role=":top-of")

        if get_sense:
            for n, d in graph.nodes(True):
                if "value" not in d:
                    print(n, d, graph[n], constant_links, graph.nodes, graph.edges)
                le, cat, sense = d["value"].le, d["value"].cat, d["value"].sense
                if cat == Rule_Frame and sense == "":
                    sense = self.fragment_to_node_converter.get_senses(le)
                d["value"] = AMRUniversal(le, cat, sense)

        # if not self.training:
        #     assert nx.is_strongly_connected(graph), ("before contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))
        #     graph = contract_graph(graph)
        #     assert nx.is_strongly_connected(graph), ("before contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))

        if set_wiki:
            node_list = [[n, d] for n, d in graph.nodes(True)]
            for n, d in node_list:
                if d["value"].le == "name":
                    names = []
                    head = None
                    for nearb in graph[n]:
                        r = graph[n][nearb]["role"]
                        if ":op" in r and "-of" not in r and int(r[3:]) not in names:
                            names.append([graph.node[nearb]["value"], int(r[3:])])
                        if r == ":name-of":
                            wikied = False
                            for nearbb in graph[nearb]:
                                if graph[nearb][nearbb]["role"] == ":wiki":
                                    wikied = True
                                    break
                            if not wikied:
                                head = nearb
                    if head:
                        names = tuple([t[0] for t in sorted(names, key=lambda t: t[1])])
                        wiki = self.fragment_to_node_converter.get_wiki(names)
                        wiki_v = Var(wiki.le + n._name)
                        graph.add_node(wiki_v, value=wiki, align=d["align"], gold=True,
                                       dep=2)  # second-order dependency
                        graph.add_edge(head, wiki_v, role=":wiki")
                        graph.add_edge(wiki_v, head, role=":wiki-of")

        assert nx.is_strongly_connected(graph), (
            "after contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))
        return graph, self.graph_to_quadruples(graph)

    graphs = []
    quadruple_batch = []
    score_batch = myunpack(*srl_prob)  # list of (h x d)
    dependent_mark_batch = appended[0]
    aligns_batch = appended[1]
    for i, (role_scores, concepts, roots_score, dependent_mark, aligns) in enumerate(
            zip(score_batch, srl_batch, roots, dependent_mark_batch, aligns_batch)):
        root_s, root_id = roots_score.max(0)
        assert roots_score.size(0) == len(concepts), (concepts, roots_score)
        root_id = root_id.data.tolist()[0]
        assert root_id < len(concepts), (concepts, roots_score)
        g, quadruples = create_connected_graph(role_scores, concepts, root_id,
                                               dependent_mark, aligns)
        graphs.append(g)
        quadruple_batch.append(quadruples)
    return graphs, quadruple_batch
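# --- Hedged sketch (not part of the method above) -----------------------------
# A minimal illustration of the connectivity repair in create_connected_graph:
# after the high-confidence (negative-cost) edges are added, the remaining
# candidate edges are tried in increasing cost order until the digraph is
# strongly connected; every edge goes in both directions, mirroring the
# role / role-of pairs. Node ids and the ":mod" label are illustrative.
def _sketch_force_connect():
    import networkx as nx
    graph = nx.DiGraph()
    graph.add_nodes_from([0, 1, 2, 3])
    graph.add_edge(0, 1, role=":ARG0")
    graph.add_edge(1, 0, role=":ARG0-of")
    candidates = sorted([(0.7, (2, 3)), (0.2, (1, 2)), (0.9, (3, 0))])
    for _, (h, d) in candidates:        # cheapest repair edges first
        if nx.is_strongly_connected(graph):
            break
        if not nx.has_path(graph, h, d):
            graph.add_edge(h, d, role=":mod", force_connect=True)
            graph.add_edge(d, h, role=":mod-of", force_connect=True)
    assert nx.is_strongly_connected(graph)
    return graph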
def relProbAndConToGraph(self, srl_batch, sourceBatch, srl_prob, roots, appended,
                         get_sense=False, set_wiki=False, normalizeMod=False):
    """
    For EDS; normalizeMod applies here.
    """
    # TODO: build a connected graph from the relation probabilities
    # srl_batch: batch of concept ids
    # srl_prob: relation scores
    # out: batch of EDS graphs
    rel_dict = self.dicts["eds_rel_dict"]

    def get_uni_var(concepts, id):
        """For EDS the id itself is the variable name; nothing else is needed."""
        return concepts[id], EDSVar(str(id))

    def create_connected_graph(role_scores, concepts, root_id, dependent, aligns, source):
        # role_scores: amr x amr x rel
        graph = nx.MultiDiGraph()
        n = len(concepts)
        role_scores = role_scores.view(n, n, -1)
        max_non_score, max_non_score_id = role_scores[:, :, 1:].max(-1)
        max_non_score_id = max_non_score_id + 1
        non_score = role_scores[:, :, 0]
        active_cost = non_score - max_non_score  # lowest-cost edge becomes active first
        candidates = []

        # add all nodes
        for h_id in range(n):
            h, h_v = get_uni_var(concepts, h_id)
            graph.add_node(h_v, value=h, align=aligns[h_id], gold=True, dep=dependent[h_id])

        constant_links = {}
        normal_edged_links = {}
        # add all pairs of edges
        for h_id in range(n):
            for d_id in range(n):
                if h_id == d_id:
                    continue
                r = rel_dict.getLabel(max_non_score_id[h_id, d_id].item())
                # relations in rel_dict are the forward argx, opx, sntx, top
                # relations and all backward relations; mod normalization is
                # already done in rel_dict, so normalize here when connecting
                # to avoid considering the same relation twice
                r_inver = EDSGraph.get_inversed_edge(r)
                h, h_v = get_uni_var(concepts, h_id)
                d, d_v = get_uni_var(concepts, d_id)
                # TODO: handle mwe, compound ner
                if active_cost[h_id, d_id] < 0:
                    normal_edged_links.setdefault((h_v, r), []).append(
                        (active_cost[h_id, d_id], d_v, r_inver))
                else:
                    candidates.append((active_cost[h_id, d_id], (h_v, d_v, r, r_inver)))

        max_edge_per_node = 1 if not self.training else 100
        for h_v, r in normal_edged_links:
            sorted_list = sorted(normal_edged_links[h_v, r], key=lambda j: j[0])
            for _, d_v, r_inver in sorted_list[:max_edge_per_node]:
                graph.add_edge(h_v, d_v, key=r, role=r)
                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
            for cost, d_v, r_inver in sorted_list[max_edge_per_node:]:  # remaining
                candidates.append((cost, (h_v, d_v, r, r_inver)))

        candidates = sorted(candidates, key=lambda j: j[0])
        for _, (h_v, d_v, r, r_inver) in candidates:
            if nx.is_strongly_connected(graph):
                break
            if not nx.has_path(graph, h_v, d_v):
                graph.add_edge(h_v, d_v, key=r, role=r, force_connect=True)
                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver, force_connect=True)

        _, root_v = get_uni_var(concepts, root_id)
        h_v = BOS_WORD
        root_symbol = EDSUniversal.TOP_EDSUniversal()
        graph.add_node(h_v, value=root_symbol, align=-1, gold=True, dep=1)
        graph.add_edge(h_v, root_v, key=":top", role=":top")
        graph.add_edge(root_v, h_v, key=":top-of", role=":top-of")

        for n, d in graph.nodes(True):
            le, pos, cat, sense, anchors = (d["value"].le, d["value"].pos, d["value"].cat,
                                            d["value"].sense, d["value"].anchors)
            # here align is the token index list
            if get_sense:
                if sense == "" or sense is None:
                    sense = self.fragment_to_node_converter.get_senses(le, pos)
            anchors = []
            # convert token indices into character offsets using the source data
            tok_index = d["align"]
            if 0 <= tok_index < len(source[ANCHOR_IND_SOURCE_BATCH]):
                anchors.extend(source[ANCHOR_IND_SOURCE_BATCH][tok_index])
            d["value"] = EDSUniversal(pos, cat, sense, le, anchors)
            d["anchors"] = anchors

        if not nx.is_strongly_connected(graph):
            logger.warn("not connected after contraction: %s, %s, %s, %s, %s, %s, %s, %s",
                        self.graph_to_quadruples(graph), graph_to_amr(graph), candidates,
                        constant_links, graph.nodes(), graph.edges(), normal_edged_links,
                        concepts)
        return graph, self.graph_to_quadruples(graph)

    graphs = []
    quadruple_batch = []
    score_batch = myunpack(*srl_prob)  # list of (h x d)
    dependent_mark_batch = appended[0]
    aligns_batch = appended[1]
    for i, (role_scores, concepts, roots_score, dependent_mark, aligns, source) in enumerate(
            zip(score_batch, srl_batch, roots, dependent_mark_batch, aligns_batch, sourceBatch)):
        # here we assume there is exactly one root
        root_s, root_id = roots_score.max(0)
        assert roots_score.size(0) == len(concepts), (concepts, roots_score)
        # since pytorch 0.4.0 (https://pytorch.org/docs/stable/tensors.html#torch.Tensor.tolist)
        # tensor.tolist() can return an int when root_id is a scalar
        root_id = root_id.item()
        assert root_id < len(concepts), (concepts, roots_score)
        g, quadruples = create_connected_graph(role_scores, concepts, root_id,
                                               dependent_mark, aligns, source)
        graphs.append(g)
        quadruple_batch.append(quadruples)
    return graphs, quadruple_batch
def relProbAndConToGraph(self, srl_batch, sourceBatch, srl_prob, roots, appended,
                         get_sense=False, set_wiki=False, normalizeMod=False):
    # srl_batch: batch of concept ids
    # srl_prob: relation scores
    # out: batch of AMRUniversal graphs
    amr_rel_dict = self.dicts["amr_rel_dict"]

    def get_uni_var(concepts, id):
        assert id < len(concepts), (id, concepts)
        uni = concepts[id]
        if uni.le in ["i", "it", "you", "they", "he", "she"] and uni.cat == Rule_Concept:
            return uni, Var(uni.le)
        le = uni.le
        if uni.cat != Rule_String:
            uni.le = uni.le.strip("/").strip(":")
        if ":" in uni.le or "/" in uni.le:
            uni.cat = Rule_String
        if uni.le == "":
            return uni, Var(le + str(id))
        return uni, Var(uni.le[0] + str(id))

    def create_connected_graph(role_scores, concepts, root_id, dependent, aligns,
                               normalizeMod=True):
        # role_scores: amr x amr x rel
        graph = nx.MultiDiGraph()
        n = len(concepts)
        role_scores = role_scores.view(n, n, -1)
        max_non_score, max_non_score_id = role_scores[:, :, 1:].max(-1)
        max_non_score_id = max_non_score_id + 1
        non_score = role_scores[:, :, 0]
        active_cost = non_score - max_non_score  # lowest-cost edge becomes active first
        candidates = []

        # add all nodes
        for h_id in range(n):
            h, h_v = get_uni_var(concepts, h_id)
            # here align is a list of token indices in our tokenization
            graph.add_node(h_v, value=h, align=aligns[h_id], gold=True, dep=dependent[h_id])

        constant_links = {}
        normal_edged_links = {}
        # add all pairs of edges
        for h_id in range(n):
            for d_id in range(n):
                if h_id == d_id:
                    continue
                r = amr_rel_dict.getLabel(max_non_score_id[h_id, d_id].item())
                # relations in amr_rel_dict are the forward argx, opx, sntx, top
                # relations and all backward relations; mod normalization should
                # already be done in amr_rel_dict
                if normalizeMod and to_be_normalize_mod(r):
                    r = get_normalize_mod(r)
                # normalize here when connecting to avoid considering the same
                # relation twice
                r_inver = get_inversed_edge(r)
                h, h_v = get_uni_var(concepts, h_id)
                d, d_v = get_uni_var(concepts, d_id)
                if concepts[h_id].is_constant() or concepts[d_id].is_constant():
                    if concepts[h_id].is_constant() and concepts[d_id].is_constant():
                        continue
                    elif concepts[h_id].is_constant():
                        constant_links.setdefault(h_v, []).append(
                            (active_cost[h_id, d_id], d_v, r, r_inver))
                    else:
                        constant_links.setdefault(d_v, []).append(
                            (active_cost[h_id, d_id], h_v, r_inver, r))
                elif active_cost[h_id, d_id] < 0:
                    if r in [":name-of", ":ARG0"] and concepts[d_id].le in ["person"]:
                        # always add both directions so connectivity checks can
                        # follow a directed path
                        graph.add_edge(h_v, d_v, key=r, role=r)
                        graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
                    else:
                        # if concepts[h_id].le == "name" and r != ":name-of":
                        #     r = ":name-of"
                        normal_edged_links.setdefault((h_v, r), []).append(
                            (active_cost[h_id, d_id], d_v, r_inver))
                else:
                    candidates.append((active_cost[h_id, d_id], (h_v, d_v, r, r_inver)))

        max_edge_per_node = 1 if not self.training else 100
        for h_v, r in normal_edged_links:
            sorted_list = sorted(normal_edged_links[h_v, r], key=lambda j: j[0])
            for _, d_v, r_inver in sorted_list[:max_edge_per_node]:
                graph.add_edge(h_v, d_v, key=r, role=r)
                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
            for cost, d_v, r_inver in sorted_list[max_edge_per_node:]:  # remaining
                candidates.append((cost, (h_v, d_v, r, r_inver)))

        for h_v in constant_links:
            _, d_v, r, r_inver = sorted(constant_links[h_v], key=lambda j: j[0])[0]
            graph.add_edge(h_v, d_v, key=r, role=r)
            graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)

        candidates = sorted(candidates, key=lambda j: j[0])
        for _, (h_v, d_v, r, r_inver) in candidates:
            if nx.is_strongly_connected(graph):
                break
            if not nx.has_path(graph, h_v, d_v):
                graph.add_edge(h_v, d_v, key=r, role=r, force_connect=True)
                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver, force_connect=True)

        _, root_v = get_uni_var(concepts, root_id)
        h_v = BOS_WORD
        root_symbol = AMRUniversal(BOS_WORD, BOS_WORD, NULL_WORD)
        graph.add_node(h_v, value=root_symbol, align=-1, gold=True, dep=1)
        graph.add_edge(h_v, root_v, key=":top", role=":top")
        graph.add_edge(root_v, h_v, key=":top-of", role=":top-of")

        if get_sense:
            for n, d in graph.nodes(True):
                if "value" not in d:
                    print(n, d, graph[n], constant_links, graph.nodes, graph.edges)
                le, cat, sense = d["value"].le, d["value"].cat, d["value"].sense
                if cat == Rule_Frame and sense == "":
                    sense = self.fragment_to_node_converter.get_senses(le)
                d["value"] = AMRUniversal(le, cat, sense)

        if not self.training:
            # contraction is only applied outside training
            if not nx.is_strongly_connected(graph):
                logger.warn("not connected before contraction: %s, %s, %s, %s, %s, %s, %s, %s",
                            self.graph_to_quadruples(graph), graph_to_amr(graph), candidates,
                            constant_links, graph.nodes(), graph.edges(), normal_edged_links,
                            concepts)
            graph = contract_graph(graph)

        if set_wiki:
            node_list = [[n, d] for n, d in graph.nodes(True)]
            for n, d in node_list:
                if d["value"].le == "name":
                    names = []
                    head = None
                    for nearb in graph[n]:
                        for key, edge_data in graph[n][nearb].items():
                            r = edge_data["role"]
                            if ":op" in r and "-of" not in r and int(r[3:]) not in names:
                                names.append([graph.node[nearb]["value"], int(r[3:])])
                            if r == ":name-of":
                                wikied = False
                                for nearbb in graph[nearb]:
                                    for _, edge_data2 in graph[nearb][nearbb].items():
                                        if edge_data2["role"] == ":wiki":
                                            wikied = True
                                            break
                                if not wikied:
                                    head = nearb
                    if head:
                        names = tuple([t[0] for t in sorted(names, key=lambda t: t[1])])
                        wiki = self.fragment_to_node_converter.get_wiki(names)
                        wiki_v = Var(wiki.le + n._name)
                        graph.add_node(wiki_v, value=wiki, align=d["align"], gold=True,
                                       dep=2)  # second-order dependency
                        graph.add_edge(head, wiki_v, key=":wiki", role=":wiki")
                        graph.add_edge(wiki_v, head, key=":wiki-of", role=":wiki-of")

        if not nx.is_strongly_connected(graph):
            logger.warn("not connected after contraction: %s, %s, %s, %s, %s, %s, %s, %s",
                        self.graph_to_quadruples(graph), graph_to_amr(graph), candidates,
                        constant_links, graph.nodes(), graph.edges(), normal_edged_links,
                        concepts)
        return graph, self.graph_to_quadruples(graph)

    graphs = []
    quadruple_batch = []
    score_batch = myunpack(*srl_prob)  # list of (h x d)
    dependent_mark_batch = appended[0]
    # aligns_batch maps each final expanded node to its recategorized node
    aligns_batch = appended[1]
    for i, (role_scores, concepts, roots_score, dependent_mark, aligns) in enumerate(
            zip(score_batch, srl_batch, roots, dependent_mark_batch, aligns_batch)):
        root_s, root_id = roots_score.max(0)
        assert roots_score.size(0) == len(concepts), (concepts, roots_score)
        # since pytorch 0.4.0 (https://pytorch.org/docs/stable/tensors.html#torch.Tensor.tolist)
        # tensor.tolist() can return an int when root_id is a scalar
        root_id = root_id.item()
        assert root_id < len(concepts), (concepts, roots_score)
        g, quadruples = create_connected_graph(role_scores, concepts, root_id,
                                               dependent_mark, aligns,
                                               normalizeMod=normalizeMod)
        graphs.append(g)
        quadruple_batch.append(quadruples)
    return graphs, quadruple_batch
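# --- Hedged sketch (not part of the method above) -----------------------------
# Every relation above is stored together with its "-of" inverse so the graph
# stays strongly connected and can be traversed in either direction. The
# helper below is an illustrative assumption about what get_inversed_edge-style
# logic does (toggle the "-of" suffix), not the repo's actual implementation.
def _sketch_inverse_edge(role):
    return role[:-3] if role.endswith("-of") else role + "-of"

assert _sketch_inverse_edge(":ARG0") == ":ARG0-of"
assert _sketch_inverse_edge(":ARG0-of") == ":ARG0"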