def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
    """Assemble a DGL heterograph from flat edge lists and type ids.

    Parameters
    ----------
    mg : MultiDiGraph
        Input metagraph; only its ``(source type, destination type, key)``
        edges are read.
    src : Numpy array
        Source nodes
    dst : Numpy array
        Destination nodes
    ntid : Numpy array
        Node types for each node
    etid : Numpy array
        Edge types for each edge
    ntypes : list
        Node types
    etypes : list
        Edge types

    Returns
    -------
    DGLGraph
        The graph returned by ``dgl.to_heterogeneous``.
    """
    verbose = self.verbose

    # Start from a single homogeneous graph that carries per-node and
    # per-edge type ids.
    if verbose:
        print('Creating one whole graph ...')
    homo = dgl.graph((src, dst))
    homo.ndata[dgl.NTYPE] = F.tensor(ntid)
    homo.edata[dgl.ETYPE] = F.tensor(etid)
    if verbose:
        print('Total #nodes:', homo.number_of_nodes())
        print('Total #edges:', homo.number_of_edges())

    # Rename names such as 'type' so that they can be used as keys
    # to nn.ModuleDict; names without an entry in RENAME_DICT are kept.
    edge_names = [RENAME_DICT.get(name, name) for name in etypes]

    # Rebuild the metagraph with the renamed edge keys.
    meta = nx.MultiDiGraph()
    for src_type, dst_type, edge_key in mg.edges(keys=True):
        meta.add_edge(src_type, dst_type,
                      key=RENAME_DICT.get(edge_key, edge_key))

    # Convert to heterograph.
    if verbose:
        print('Convert to heterograph ...')
    hetero = dgl.to_heterogeneous(homo, ntypes, edge_names, metagraph=meta)
    if verbose:
        print('#Node types:', len(hetero.ntypes))
        print('#Canonical edge types:', len(hetero.etypes))
        print('#Unique edge type names:', len(set(hetero.etypes)))
    return hetero
def process(self):
    """Process raw data into a heterograph with features, labels and masks.

    Reads ``{name}_retweet_graph.mat``, ``{name}_node_feature.npy`` and
    ``{name}_edge_feature.npy`` from ``self.raw_path``. Only the first half
    of the stored edges is kept; it is mirrored so that every kept edge
    also has a reversed copy. The result is stored in ``self.graph`` after
    ``self._random_split`` has been applied to it.
    """
    mat = sio.loadmat(
        os.path.join(self.raw_path, f'{self.name}_retweet_graph.mat'))
    coo = mat['graph'].tocoo()
    num_edges = len(coo.row)
    half = int(num_edges / 2)

    # Forward edges first, then the same edges reversed.
    fwd_src = coo.row[:half]
    fwd_dst = coo.col[:half]
    all_src = np.concatenate((fwd_src, fwd_dst))
    all_dst = np.concatenate((fwd_dst, fwd_src))
    graph = dgl.graph((all_src, all_dst))

    news_labels = mat['label'].squeeze()
    num_news = len(news_labels)

    node_feature = np.load(
        os.path.join(self.raw_path, f'{self.name}_node_feature.npy'))
    edge_feature = np.load(
        os.path.join(self.raw_path, f'{self.name}_edge_feature.npy'))
    edge_feature = edge_feature[:half]
    graph.ndata['feat'] = th.tensor(node_feature)
    # Edge features are repeated once for the reversed copies.
    graph.edata['feat'] = th.tensor(np.tile(edge_feature, (2, 1)))

    # An edge is labelled 1 when it enters or leaves a positively
    # labelled news node.
    pos_news = news_labels.nonzero()[0]
    edge_labels = th.zeros(num_edges)
    incoming = graph.in_edges(pos_news, form='eid')
    edge_labels[incoming] = 1
    outgoing = graph.out_edges(pos_news, form='eid')
    edge_labels[outgoing] = 1
    graph.edata['label'] = edge_labels

    # Node type 0 for ids below num_news, 1 for the rest; edge type 0 for
    # the first half of the edges, 1 for the mirrored half.
    etypes = th.ones(graph.num_edges(), dtype=int)
    etypes[:half] = 0
    ntypes = th.ones(graph.num_nodes(), dtype=int)
    ntypes[graph.nodes() < num_news] = 0
    graph.ndata['_TYPE'] = ntypes
    graph.edata['_TYPE'] = etypes

    hg = dgl.to_heterogeneous(graph, ['v', 'u'], ['forward', 'backward'])
    self._random_split(hg, self.seed, self.train_size, self.val_size)
    self.graph = hg
def forward(self, **kwargs):
    """Score the two answer choices of each example and compute the loss.

    All inputs are passed as keyword arguments.

    :param input_ids: token ids, indexed as ``[:, choice, :]``
        (presumably ``[batch_size, 2, max_seq_length]`` -- TODO confirm).
    :param attention_mask: attention mask, indexed like ``input_ids``. For
        every example its sum must equal the number of ``token_id`` nodes
        in the matching synset graph.
    :param token_type_ids: segment ids, indexed like ``input_ids``. Only
        read when ``config.text_embed_model == "bert"``.
    :param label_ids: gold choice per example. No loss is computed when it
        is missing or when its first entry is ``-1``.
    :param batch_synset_graphs: indices selecting this batch's graphs from
        each per-choice graph collection.
    :param wn_synset_graphs: per-choice collections of heterographs;
        ``wn_synset_graphs[choice][i]`` is the graph for index ``i``.
    :param qas_ids: passed through unchanged to the caller.
    :return: tuple ``(loss, logits, label_ids, qas_ids)``; ``loss`` is
        ``None`` when no loss is computed. ``logits`` holds the per-choice
        scores concatenated along dimension 1.
    :raises ValueError: if ``config.text_embed_model`` is not one of
        ``"bert"``, ``"roberta"`` or ``"roberta_base"``.
    """
    label_ids = kwargs.get("label_ids")
    input_ids_all = kwargs.get("input_ids")
    attention_mask_all = kwargs.get("attention_mask")
    token_type_ids_all = kwargs.get("token_type_ids")
    batch_graph_ids = kwargs.get("batch_synset_graphs")
    wn_synset_graphs_all = kwargs.get("wn_synset_graphs")

    # Fail early with a clear message instead of hitting an unbound
    # ``text_output`` further down.
    embed_model_name = self.config.text_embed_model
    use_bert = embed_model_name == "bert"
    if not use_bert and embed_model_name not in ("roberta", "roberta_base"):
        raise ValueError(
            "Unsupported text_embed_model: {!r}".format(embed_model_name))

    # Relation-derived names do not depend on the choice being scored, so
    # build them once.
    relation_list = self.config.relation_list
    inverse_relation_list = ["{}_".format(rel) for rel in relation_list]
    id_type_list = ["wn{}_id".format(rel) for rel in relation_list]
    context_type_list = ["wn{}_context".format(rel) for rel in relation_list]

    num_choices = 2
    choice_score_list = []
    for num in range(num_choices):
        input_ids = input_ids_all[:, num, :]
        attention_mask = attention_mask_all[:, num, :]
        device = input_ids.device

        wn_synset_graphs = wn_synset_graphs_all[num]
        batch_synset_graphs = [wn_synset_graphs[i] for i in batch_graph_ids]

        # 1st layer: contextual text encoder.
        if use_bert:
            text_output = self.text_embed_model(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids_all[:, num, :],
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states
            )[0]
        else:
            text_output = self.text_embed_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )[0]

        # Build the per-example graphs.
        batch_context_graphs_list = []
        batch_wn_graphs_list = []
        batch_entity2token_graphs_list = []
        batch_entity2token_graphs_nell_list = []
        for i, g in enumerate(batch_synset_graphs):
            assert (len(g.nodes("token_id")) == torch.sum(attention_mask[i, :]))
            # reconstruct context graph
            context_g, wn_g = self.reconstruct_dgl_graph(
                g, relation_list, inverse_relation_list, id_type_list,
                context_type_list, text_output[i, :, :], device)
            entity2token_graph, entity2token_graph_nell = \
                self.construct_entity2token_graph(i, g, text_output, device)
            batch_entity2token_graphs_list.append(entity2token_graph)
            batch_entity2token_graphs_nell_list.append(entity2token_graph_nell)
            batch_context_graphs_list.append(context_g)
            batch_wn_graphs_list.append(wn_g)

        # Run the relational GCNs on the batched graphs, write the result
        # back as node features, then split into per-example graphs again.
        batch_context_graphs_dgl = dgl.batch(batch_context_graphs_list)
        graph_context_embedding = self.rgcn_context(
            batch_context_graphs_dgl, batch_context_graphs_dgl.ndata['feature'])
        batch_context_graphs_dgl.nodes["wn_concept_context"].data["feature"] = \
            graph_context_embedding["wn_concept_context"]
        batch_context_graphs_list = dgl.unbatch(batch_context_graphs_dgl)

        batch_wn_graphs_dgl = dgl.batch(batch_wn_graphs_list)
        graph_wn_embedding = self.rgcn_wn(
            batch_wn_graphs_dgl, batch_wn_graphs_dgl.ndata['feature'])
        batch_wn_graphs_dgl.nodes["wn_concept_id"].data["feature"] = \
            graph_wn_embedding["wn_concept_id"]
        batch_wn_graphs_list = dgl.unbatch(batch_wn_graphs_dgl)

        memory_output_new = text_output
        embed_shape = (memory_output_new.shape[0], memory_output_new.shape[1],
                       self.concept_embed_size)
        context_embed_new = torch.zeros(
            embed_shape, dtype=torch.float32, device=device)
        concept_embed_new = torch.zeros(
            embed_shape, dtype=torch.float32, device=device)
        nell_embed_new = torch.zeros(
            embed_shape, dtype=torch.float32, device=device)

        for idx, g_e2t in enumerate(batch_entity2token_graphs_list):
            concept_data = g_e2t.nodes["wn_concept_id"].data
            token_data = g_e2t.nodes["token_id"].data
            sentinel_data = g_e2t.nodes["sentinel_id"].data

            concept_data["context_feature"] = \
                batch_context_graphs_list[idx].nodes["wn_concept_context"].data["feature"]
            concept_data["id_feature"] = \
                batch_wn_graphs_list[idx].nodes["wn_concept_id"].data["feature"]
            token_data["id_feature"] = self.projected_token_text(
                token_data["context_feature"])
            sentinel_data["id_feature"] = torch.zeros_like(
                token_data["id_feature"], device=device)
            sentinel_data["context_feature"] = torch.zeros_like(
                token_data["context_feature"], device=device)

            # The attention layers take a homogeneous graph; convert, run
            # them, and convert back to read out the token nodes.
            g_e2t_homo = dgl.to_homogeneous(
                g_e2t, ndata=['id_feature', 'context_feature'])
            g_e2t_homo.ndata['context_feature'] = self.gat_context(
                g_e2t_homo, g_e2t_homo.ndata['context_feature'])
            g_e2t_homo.ndata['id_feature'] = self.gat_wn(
                g_e2t_homo, g_e2t_homo.ndata['id_feature'])
            tmp_graph = dgl.to_heterogeneous(
                g_e2t_homo, g_e2t.ntypes, g_e2t.etypes)

            # Order the token nodes by their id in the homogeneous graph
            # (the constant offset does not affect the ordering).
            tmp_argsort = torch.argsort(
                tmp_graph.ndata[dgl.NID]["token_id"]
                - tmp_graph.num_nodes("sentinel_id"))
            num_tokens = tmp_graph.num_nodes("token_id")
            tmp_token_data = tmp_graph.nodes["token_id"].data
            concept_embed_new[idx, :num_tokens, :] = \
                tmp_token_data["id_feature"].index_select(0, tmp_argsort)
            context_embed_new[idx, :num_tokens, :] = \
                tmp_token_data["context_feature"].index_select(0, tmp_argsort)

            # Same procedure for the NELL entity-to-token graph.
            g_e2t_nell = batch_entity2token_graphs_nell_list[idx]
            g_e2t_nell_homo = dgl.to_homogeneous(
                g_e2t_nell, ndata=['id_feature'])
            g_e2t_nell_homo.ndata['id_feature'] = self.gat_nell(
                g_e2t_nell_homo, g_e2t_nell_homo.ndata['id_feature'])
            tmp_graph_nell = dgl.to_heterogeneous(
                g_e2t_nell_homo, g_e2t_nell.ntypes, g_e2t_nell.etypes)
            nell_tmp_argsort = torch.argsort(
                tmp_graph_nell.ndata[dgl.NID]["token_id"]
                - tmp_graph_nell.num_nodes("sentinel_id")
                - tmp_graph_nell.num_nodes("nell_concept_id"))
            nell_embed_new[idx, :tmp_graph_nell.num_nodes("token_id"), :] = \
                tmp_graph_nell.nodes["token_id"].data["id_feature"].index_select(
                    0, nell_tmp_argsort)

        # Append the enabled knowledge embeddings to the text encoding.
        if self.use_nell:
            memory_output_new = torch.cat((memory_output_new, nell_embed_new), 2)
        if self.use_context_graph or self.use_wn:
            if self.use_context_graph and self.use_wn:
                k_memory = torch.cat((concept_embed_new, context_embed_new), 2)
            elif self.use_wn:
                k_memory = concept_embed_new
            else:
                k_memory = context_embed_new
            memory_output_new = torch.cat((memory_output_new, k_memory), 2)

        att_output = self.self_matching(
            memory_output_new, attention_mask.unsqueeze(2))

        # Output layer: score the choice from the first sequence position.
        choice_score_list.append(self.qa_kt_outputs(att_output[:, 0, :]))

    logits = torch.cat(
        [score.unsqueeze(1).squeeze(-1) for score in choice_score_list],
        dim=1
    )

    # A missing label tensor, or -1 as its first entry, means "no gold
    # labels": skip the loss.
    if label_ids is not None and label_ids[0] != -1:
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, num_choices), label_ids.view(-1))
    else:
        loss = None
    return loss, logits, label_ids, kwargs.get("qas_ids")
def test_edge_softmax(g, norm_by, idtype):
    """Check that edge_softmax agrees on homogeneous and heterogeneous graphs.

    The same random edge scores are fed to ``edge_softmax`` once through the
    homogeneous view of the test graph and once as a per-edge-type dict on
    the heterograph itself. Scores for all four edge types, and the input
    gradients of a loss over the 'plays' and 'follows' scores, must match.

    Parameters
    ----------
    g
        Ignored: the graph is rebuilt from ``idtype`` via
        ``create_test_heterograph``.
    norm_by
        Forwarded to ``edge_softmax`` as ``norm_by``.
    idtype
        Index dtype used to build the test heterograph.
    """
    print("params", norm_by, idtype)

    g = create_test_heterograph(idtype)

    x1 = F.randn((g.num_edges('plays'), feat_size))
    x2 = F.randn((g.num_edges('follows'), feat_size))
    x3 = F.randn((g.num_edges('develops'), feat_size))
    x4 = F.randn((g.num_edges('wishes'), feat_size))

    g['plays'].edata['eid'] = x1
    g['follows'].edata['eid'] = x2
    g['develops'].edata['eid'] = x3
    g['wishes'].edata['eid'] = x4

    #################################################################
    # edge_softmax() on homogeneous graph
    #################################################################

    with F.record_grad():
        hm_g = dgl.to_homogeneous(g)
        # Concatenation order (develops, follows, plays, wishes) must be
        # the same as the one used for the heterograph gradients below;
        # presumably it follows the edge type order of `g` -- TODO confirm.
        hm_x = F.cat((x3, x2, x1, x4), 0)
        hm_e = F.attach_grad(F.clone(hm_x))
        score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
        hm_g.edata['score'] = score_hm
        ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
        r1 = ht_g.edata['score'][('user', 'plays', 'game')]
        r2 = ht_g.edata['score'][('user', 'follows', 'user')]
        r3 = ht_g.edata['score'][('developer', 'develops', 'game')]
        r4 = ht_g.edata['score'][('user', 'wishes', 'game')]
        F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
        grad_edata_hm = F.grad(hm_e)

    #################################################################
    # edge_softmax() on heterogeneous graph
    #################################################################

    e1 = F.attach_grad(F.clone(x1))
    e2 = F.attach_grad(F.clone(x2))
    e3 = F.attach_grad(F.clone(x3))
    e4 = F.attach_grad(F.clone(x4))
    e = {
        ('user', 'follows', 'user'): e2,
        ('user', 'plays', 'game'): e1,
        ('user', 'wishes', 'game'): e4,
        ('developer', 'develops', 'game'): e3,
    }
    with F.record_grad():
        score = edge_softmax(g, e, norm_by=norm_by)
        r5 = score[('user', 'plays', 'game')]
        r6 = score[('user', 'follows', 'user')]
        r7 = score[('developer', 'develops', 'game')]
        r8 = score[('user', 'wishes', 'game')]
        F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
        grad_edata_ht = F.cat(
            (F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0)

    # correctness check
    assert F.allclose(r1, r5)
    assert F.allclose(r2, r6)
    assert F.allclose(r3, r7)
    assert F.allclose(r4, r8)
    assert F.allclose(grad_edata_hm, grad_edata_ht)
def _extract_features(self, graphs, ai2d_ann, image, layers):
    """
    Extracts features from the original AI2D annotation and adds them to
    the AI2D-RST graphs.

    Parameters:
        graphs: A dictionary of NetworkX graphs for AI2D-RST annotation.
                The keys read here are 'grouping', 'connectivity' and
                'discourse'.
        ai2d_ann: A dictionary containing the original AI2D annotation.
        image: An image of the diagram from the original AI2D dataset.
        layers: A string defining annotation layers to include in the
                updated graphs. The values handled specifically are
                'grouping+connectivity', 'connectivity' and 'discourse';
                the grouping graph is always updated.

    Returns:
        One of the following:
          - None, if the grouping graph does not expose nodes (e.g. it is
            None).
          - The discourse graph as a NetworkX graph, if layers is
            'discourse' and self._return_nx is set.
          - Otherwise the graphs dictionary with updated features; for
            'discourse' its entry is replaced by a DGL heterograph.
    """
    # To begin with, build the grouping graph, which provides the layout
    # information on all diagram elements, which can be then picked out in
    # other graphs, if necessary.
    graph = graphs['grouping']

    # Check that a graph exists
    try:
        nodes = graph.nodes(data=True)
    except AttributeError:
        return None

    # Begin extracting the features by getting the diagram image shape
    # and the number of pixels in the image
    h, w = image.shape[:2]
    n_pix = h * w

    # Collect layout features and encoded node types for every node
    node_features = {}
    for node, features in nodes:

        # Fetch the node type from its features under the key 'kind'
        node_type = features['kind']

        # Parse layout annotation
        layout_feats = self._parse_ai2d_layout(
            ai2d_ann,   # annotation
            h,          # image height
            w,          # image width
            n_pix,      # n of pixels
            node_type,  # elem type
            node        # node id
        )

        node_features[node] = {
            'features': layout_feats,
            'kind': self.node_dict['grouping'][node_type]
        }

    # Update node attributes in the grouping graph using layout features
    nx.set_node_attributes(graph, node_features)

    # Calculate features for grouping nodes based on their children. This
    # requires a directed tree graph.
    group_tree = nx.dfs_tree(graph, source="I0")

    # Encoded node types that mark grouping nodes and image constants
    group_kinds = [
        self.node_dict['grouping']['imageConsts'],
        self.node_dict['grouping']['group']
    ]

    # Identifiers of those nodes; a set keeps the membership tests below
    # constant-time.
    groups = {
        node_id for node_id, attr in graph.nodes(data=True)
        if attr['kind'] in group_kinds
    }

    # Iterate over the nodes in the graph
    for n, attr in graph.nodes(data=True):

        # Only grouping nodes and image constants are processed here
        if n not in groups:
            continue

        # Nodes reached by a depth-first search that starts from this
        # node in the grouping tree
        reached = nx.dfs_predecessors(group_tree, n)

        # Leave out groups; each group is processed independently
        members = [m for m in reached if m not in groups]

        # Create a subgraph of these nodes and stack their layout
        # features into a 2D numpy array
        n_subgraph = graph.subgraph(members)
        stacked_feats = np.array(
            [ad['features'] for _, ad in n_subgraph.nodes(data=True)])

        # Get average centre point for group by slicing the array
        x_avg = np.average(stacked_feats[:, 0])
        y_avg = np.average(stacked_feats[:, 1])

        # Add up their area
        a_sum = np.sum(stacked_feats[:, 2])

        # Average the solidity
        s_avg = np.average(stacked_feats[:, 3])

        # Concatenate the features
        layout_feats = np.concatenate(
            [[x_avg], [y_avg], [a_sum], [s_avg]], axis=0)

        # Update group features
        nx.set_node_attributes(
            graph, {n: {'features': layout_feats, 'kind': attr['kind']}})

    # Add edge types to the grouping layer, as these are not defined in
    # the JSON annotation.
    edge_features = {
        (src, dst): {'kind': 'grouping'} for src, dst in graph.edges()
    }
    nx.set_edge_attributes(graph, edge_features)

    # Encode edge features
    self._encode_edges(graph, self.edge_dict['grouping'])

    # Update the grouping graph in the graphs dictionary
    graphs['grouping'] = graph

    # Now that the grouping layer has been created, check which other
    # annotation layers must be included in the graph-based
    # representation.

    # The combination of grouping and connectivity layers is a relatively
    # simple case.
    if layers == "grouping+connectivity":

        # If a connectivity graph exists, merge it with the grouping graph
        if graphs['connectivity'] is not None:

            graph = nx.compose(graphs['connectivity'], graphs['grouping'])

            # Encode edge type information using numerical labels
            self._encode_edges(graph, self.edge_dict['connectivity'])

            # Update the grouping graph
            graphs['grouping'] = graph

    # The connectivity layer alone is a bit more complex, as the children
    # of grouping nodes need to be copied over to the connectivity graph.
    if layers == 'connectivity' and graphs['connectivity'] is not None:

        conn_graph = graphs['connectivity']
        group_graph = graphs['grouping']

        # Get a list of grouping nodes in the connectivity graph
        grouping_nodes = [
            node_id for node_id, attr_dict in list(conn_graph.nodes(data=True))
            if attr_dict['kind'] == 'group'
        ]

        # If grouping nodes are found, get their children and add them to
        # the graph
        if len(grouping_nodes) > 0:

            # Create a directed tree graph using depth-first search,
            # starting from the image constant I0.
            group_tree = nx.dfs_tree(group_graph, source="I0")

            for gn in grouping_nodes:

                # Resolve grouping nodes by adding their children to the
                # connectivity graph
                self._resolve_grouping_node(gn, group_tree,
                                            group_graph, conn_graph)

        # Copy the node features from the grouping graph for every node
        # that is now part of the connectivity graph.
        n_subgraph = group_graph.subgraph(conn_graph.nodes)
        conn_graph.add_nodes_from(n_subgraph.nodes(data=True))

        # Encode edge type information using numerical labels
        self._encode_edges(conn_graph, self.edge_dict['connectivity'])

        # Update the connectivity graph in the graphs dictionary
        graphs['connectivity'] = conn_graph

    # Start building the discourse graph by getting node features from
    # the grouping graph.
    if layers == 'discourse':

        group_graph = graphs['grouping']
        rst_graph = graphs['discourse']

        # Reverse node type dictionary for the grouping layer
        rev_group_dict = {
            int(v.item()): k
            for k, v in self.node_dict['grouping'].items()
        }

        # Re-encode node types to ensure that node types do not clash
        # with those defined for discourse graph
        upd_node_types = {
            k: rev_group_dict[int(v['kind'].item())]
            for k, v in group_graph.nodes(data=True)
        }
        nx.set_node_attributes(group_graph, upd_node_types, 'kind')

        # Get the nodes participating in the discourse graph from the
        # grouping graph and add them back to the discourse graph with
        # their features. These will overwrite the original nodes.
        subgraph = group_graph.subgraph(rst_graph.nodes)
        rst_graph.add_nodes_from(subgraph.nodes(data=True))

        # Check if discourse graph contains groups or split nodes. Split
        # nodes are used to preserve the tree structure in case a diagram
        # element participates in multiple RST relations. A copy is
        # iterated because the graph is modified along the way.
        for n, attr_dict in rst_graph.copy().nodes(data=True):

            # Check if the node is a group
            if 'group' in attr_dict['kind']:

                # Create a directed tree graph using depth-first search,
                # starting from the image constant I0.
                group_tree = nx.dfs_tree(group_graph, source="I0")

                # Resolve grouping nodes by adding their children to the
                # discourse graph.
                self._resolve_grouping_node(n, group_tree,
                                            group_graph, rst_graph)

            # Check node for the copy_of attribute, which contains a
            # reference to the node which has been split.
            if 'copy_of' in attr_dict:

                # Get the identifier of the node in AI2D layout annotation
                n_orig_id = attr_dict['copy_of']
                n_orig_kind = attr_dict['kind']

                # Fetch node data from the AI2D layout annotation
                layout_feats = self._parse_ai2d_layout(
                    ai2d_ann, h, w, n_pix, n_orig_kind, n_orig_id)

                # Update node features in the graph
                nx.set_node_attributes(
                    rst_graph,
                    {n: {'features': layout_feats, 'kind': n_orig_kind}})

            # Check if the node is a relation
            if 'relation' in attr_dict['kind']:

                # Get integer label for RST relation
                rst_int_label = self.node_dict['relations'][
                    attr_dict['rel_name']]

                # Get node labels and encode using label binarizer
                rst_label = self._rst_binarizer.transform(rst_int_label)

                # Check if label smoothing is requested:
                if self._smooth_labels:

                    # Cast into float for label smoothing
                    rst_label = np.asarray(rst_label, dtype=np.float64)

                    # Smooth the labels by a factor of 0.1
                    rst_label *= (1 - 0.1)
                    rst_label += (0.1 / rst_label.shape[1])

                # Set the encoded label as the node's features
                nx.set_node_attributes(
                    rst_graph, {n: {'features': rst_label.flatten()}})

        # Check if a NetworkX graph should be returned. Note that this
        # returns the discourse graph alone, not the graphs dictionary.
        if self._return_nx:
            return rst_graph

        # Convert node identifiers to integers. This needs to be performed
        # before creating a heterograph.
        rst_graph = nx.convert_node_labels_to_integers(rst_graph,
                                                       first_label=0)

        # Get the node types; their sorted unique values; and, for each
        # node, the index of its type among the unique values.
        node_kinds = np.asarray([
            attr['kind'] for _, attr in rst_graph.nodes(data=True)
        ]).flatten()
        ntypes, node_ixs = np.unique(node_kinds, return_inverse=True)
        node_ixs = node_ixs.reshape(-1).astype(np.int64)

        # Do the same for edges
        edge_kinds = np.asarray([
            attr['kind'] for _, _, attr in rst_graph.edges(data=True)
        ]).flatten()
        etypes, edge_ixs = np.unique(edge_kinds, return_inverse=True)
        edge_ixs = edge_ixs.reshape(-1).astype(np.int64)

        # Create DGL graph object from the discourse graph
        g = dgl.from_networkx(rst_graph)

        # Assign node and edge types
        g.ndata[dgl.NTYPE] = torch.LongTensor(node_ixs)
        g.edata[dgl.ETYPE] = torch.LongTensor(edge_ixs)

        # Create a DGL heterograph from the DGL graph object
        hg = dgl.to_heterogeneous(g, ntypes, etypes)

        # Loop over node types in the heterograph
        for ntype in hg.ntypes:

            # Identifiers of this type's nodes in the integer-labelled
            # discourse graph
            rst_node_ids = hg.nodes[ntype].data[dgl.NID].tolist()

            # Stack the features of these nodes
            features = np.vstack([
                rst_graph.nodes[node_id]['features']
                for node_id in rst_node_ids
            ])

            # Add features to DGL heterograph
            hg.nodes[ntype].data['features'] = torch.from_numpy(features)

        # Update the RST graph
        graphs['discourse'] = hg

    # Return all graphs
    return graphs