def forward(self, gl, gr):
    """Jointly encode a pair of graphs and map the pair to an output.

    For each of ``self.depth`` rounds:
    1. Cross-graph attention ``self.attention`` is computed on the current
       node features of both graphs.
    2. Each graph does one round of sum message passing.
    3. The attention output coming from the other graph is concatenated to
       the aggregated features.
    4. The round's feed-forward layer ``ff<idx>`` (shared by both graphs) is
       applied.

    Finally, node features are sum-pooled per graph. The two graph vectors are
    concatenated and passed through ``self.ff``.

    Parameters
    ----------
    gl, gr : DGLGraph
        The two graphs. Each must carry node features in ``ndata["feat"]``.
        The caller's graphs are not modified; work happens on ``local_var()``
        copies.

    Returns
    -------
    Tensor
        Output of ``self.ff`` on the concatenated pooled embeddings.
    """
    gl = gl.local_var()
    gr = gr.local_var()
    xl = gl.ndata["feat"]
    xr = gr.ndata["feat"]
    # ``copy_src`` is deprecated in DGL in favour of the equivalent ``copy_u``.
    message_fn = dgl.function.copy_u('x', 'm')
    reduce_fn = dgl.function.sum('m', 'x')
    for idx in range(self.depth):
        # Attention uses the features from *before* this round's propagation.
        # mu_rl is appended to the left graph, mu_lr to the right graph.
        mu_lr, mu_rl = self.attention(xl, xr)
        gl.ndata["x"] = xl
        gr.ndata["x"] = xr
        gl.update_all(message_fn, reduce_fn)
        gr.update_all(message_fn, reduce_fn)
        xl = torch.cat([gl.ndata["x"], mu_rl], dim=-1)
        xr = torch.cat([gr.ndata["x"], mu_lr], dim=-1)
        layer = getattr(self, "ff%s" % idx)
        xl = layer(xl)
        xr = layer(xr)
    # Store the final node features (also correct when depth == 0).
    gl.ndata["x"] = xl
    gr.ndata["x"] = xr
    xl = dgl.sum_nodes(gl, "x")
    xr = dgl.sum_nodes(gr, "x")
    return self.ff(torch.cat([xl, xr], dim=-1))
def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
    """Mix per-ally Q values into one Q_tot value per graph in the batch.

    Each ally's Q value is scaled by the non-negative weight
    ``|w_ff(graph, w_gn(graph, node_feature))|``, and the products are summed
    per graph. A state term, the per-graph sum of
    ``v_ff(graph, v_gn(graph, node_feature))`` over ally nodes, is then added.

    Parameters
    ----------
    graph : dgl.BatchedDGLGraph
        Batched graph (asserted). ``ndata['node_feature']`` is used as a
        scratch field: it is overwritten and then removed.
    node_feature : Tensor
        Node features fed to both ``w_gn`` and ``v_gn``.
    qs : Tensor
        One Q value per ally node; reshaped to ``[#allies x 1]``.
    ally_node_type_index : optional
        Node type used to select ally nodes.

    Returns
    -------
    Tensor
        Flattened Q_tot, one entry per graph.
    """
    assert isinstance(graph, dgl.BatchedDGLGraph)

    def pooled_over_allies(ally_values):
        # Scatter [#allies x 1] values into a zero [#nodes x 1] buffer, sum it
        # per graph, then drop the temporary node field again.
        buf = torch.zeros(size=(graph.number_of_nodes(), 1), device=device)
        buf[allies, :] = ally_values
        graph.ndata['node_feature'] = buf
        pooled = dgl.sum_nodes(graph, 'node_feature')
        graph.ndata.pop('node_feature')
        return pooled

    weight_embedding = self.w_gn(graph, node_feature)  # [#nodes x node_dim]
    weights = torch.abs(self.w_ff(graph, weight_embedding))  # [#nodes x 1]
    allies = get_filtered_node_index_by_type(graph, ally_node_type_index)
    device = weight_embedding.device

    weighted_q_sum = pooled_over_allies(weights[allies, :] * qs.view(-1, 1))

    value_embedding = self.v_gn(graph, node_feature)  # [#nodes x node_dim]
    ally_values = self.v_ff(graph, value_embedding)[allies, :]  # [#allies x 1]
    state_value = pooled_over_allies(ally_values)

    return (weighted_q_sum + state_value).view(-1)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Embed node features, run the GIN layers and return a graph-level score.

    Parameters
    ----------
    g : DGLGraph
        (Batched) input graph. ``g.ndata['h']`` is overwritten with the
        projected node representations used for readout.
    h : Tensor
        Input node features; cast to float before embedding.
    e, snorm_e :
        Unused; accepted for interface compatibility.
    snorm_n : Tensor
        Forwarded unchanged to every GIN layer.

    Returns
    -------
    Tensor
        ``self.linear_prediction`` applied to the pooled graph features.
    """
    # Cast to float: the dtype was modified for the new dataset.
    h = h.float()
    # This used to be ``h.cuda()``, which crashes on CPU-only machines.
    # Instead, move the input to the device of the embedding layer's
    # parameters. Assumes ``embedding_lin`` owns at least one parameter
    # (e.g. an nn.Linear) -- TODO confirm.
    h = h.to(next(self.embedding_lin.parameters()).device)
    h = self.embedding_lin(h)
    # Residual source: the embedded input, reused by every layer
    # (not the previous layer's output).
    h_in = h
    for i in range(self.n_layers):
        h = self.ginlayers[i](g, h, snorm_n)
        if self.residual == "gated":
            z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
            h = z * h + (torch.ones_like(z) - z) * h_in
        elif self.residual:
            h += h_in
    g.ndata['h'] = self.linear_ro(h)
    readout_fns = {"sum": dgl.sum_nodes, "max": dgl.max_nodes, "mean": dgl.mean_nodes}
    # Unknown readout names fall back to summation.
    hg = readout_fns.get(self.readout, dgl.sum_nodes)(g, 'h')
    return self.linear_prediction(hg)
def forward(self, g, h, e, snorm_n, snorm_e):
    """Embed node features, apply the conv stack and predict per graph.

    Every layer in ``self.layers`` receives ``(g, h, snorm_n)``. When
    ``self.residual`` is set, that layer's input is added back to its output:
    through a sigmoid gate for ``"gated"``, otherwise by plain in-place
    addition.

    Parameters
    ----------
    g : DGLGraph
        Input graph. ``g.ndata['h']`` is overwritten with the readout inputs.
    h : Tensor
        Node features; cast to float first (dtype modified for new dataset).
    e, snorm_e :
        Unused.
    snorm_n : Tensor
        Passed to each layer.

    Returns
    -------
    Tensor
        ``self.linear_predict`` of the pooled graph representation.
    """
    h = self.in_feat_dropout(self.embedding_lin(h.float()))
    for layer in self.layers:
        layer_input = h
        h = layer(g, h, snorm_n)
        if not self.residual:
            continue
        if self.residual == "gated":
            gate = torch.sigmoid(self.W_g(torch.cat([h, layer_input], dim=1)))
            h = gate * h + (torch.ones_like(gate) - gate) * layer_input
        else:
            h += layer_input
    g.ndata['h'] = self.linear_ro(h)
    poolers = {"sum": dgl.sum_nodes, "max": dgl.max_nodes, "mean": dgl.mean_nodes}
    # Anything else falls back to sum readout.
    graph_repr = poolers.get(self.readout, dgl.sum_nodes)(g, 'h')
    return self.linear_predict(graph_repr)
def forward(self, g):
    """Classify a word/concept heterograph.

    Steps:
    1. Embed the raw ``'x'`` node fields of node types ``'word'`` and
       ``'concept'`` into ``'feat'``.
    2. Embed the raw ``'h'`` edge fields of edge types ``'A'``, ``'B'`` and
       ``'C'`` into ``'weight'``. Dropout is applied to every embedding.
    3. Run ``self.rgcn``.
    4. Classify the concatenation of the per-type node-sum readouts.

    Parameters
    ----------
    g : DGLHeteroGraph
        Needs node types 'word'/'concept' with ``data['x']`` and edge types
        'A'/'B'/'C' with ``data['h']``. NOTE: the ``'feat'`` and ``'weight'``
        fields written here persist on ``g`` after the call.

    Returns
    -------
    Tensor
        Output of ``self.classify``.
    """
    node_embeddings = (('word', self.word_embedding),
                       ('concept', self.concept_embedding))
    for ntype, embed in node_embeddings:
        g.nodes[ntype].data['feat'] = self.dropout(embed(g.nodes[ntype].data['x']))
    edge_embeddings = (('A', self.w_w_embedding),
                       ('B', self.w_c_embedding),
                       ('C', self.c_w_embedding))
    for etype, embed in edge_embeddings:
        g.edges[etype].data['weight'] = self.dropout(embed(g.edges[etype].data['h']))

    h = self.rgcn(g, g.ndata['feat'])
    with g.local_scope():
        g.ndata['h'] = h
        # Graph representation: *sum* readout for each node type (not an
        # average, as an old comment claimed), concatenated on the feature
        # axis with word first, then concept.
        hg = torch.cat((dgl.sum_nodes(g, 'h', ntype='word'),
                        dgl.sum_nodes(g, 'h', ntype='concept')), -1)
        return self.classify(hg)
def get_q(self, graph, node_feature, qs, ws=None):
    """Return per-cluster aggregated Q values plus a state-dependent bias.

    Each ally's Q value is multiplied by its cluster weights
    (``[#allies x #clusters]``; computed with ``self.get_w`` when ``ws`` is
    None) and summed per graph. ``self.q_b_net`` applied to the per-graph sum
    of ``node_feature`` is then added.

    Side effects: ``graph.ndata['q']`` and ``graph.ndata['node_feature']`` are
    set temporarily and removed before returning.

    Returns
    -------
    Tensor
        ``[#graphs x #clusters]``.
    """
    allies = get_filtered_node_index_by_type(graph, NODE_ALLY)
    weights = self.get_w(graph, node_feature) if ws is None else ws

    per_node_q = torch.zeros(size=(graph.number_of_nodes(), self.num_clusters),
                             device=node_feature.device)
    per_node_q[allies, :] = qs.view(-1, 1) * weights
    graph.ndata['q'] = per_node_q
    cluster_q = dgl.sum_nodes(graph, 'q')  # [#graphs x #clusters]

    # State-dependent bias from the summed node features.
    graph.ndata['node_feature'] = node_feature
    bias = self.q_b_net(dgl.sum_nodes(graph, 'node_feature'))

    for scratch_key in ('node_feature', 'q'):
        graph.ndata.pop(scratch_key)
    return cluster_q + bias
def test_simple_readout():
    """Check sum/mean/max node and edge readouts on one graph and on a batch.

    ``g1`` has 3 nodes and 3 edges. ``g2`` has 4 nodes and no edges, so its
    rows of the batched edge readouts must be zero vectors.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # g2 deliberately gets no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random draws, in order: node features, edge features, node weights.
    n1, n2, e1 = F.randn((3, 5)), F.randn((4, 5)), F.randn((3, 5))
    w1, w2 = F.randn((3, )), F.randn((4, ))

    def weighted_sum(feat, w):
        return F.sum(feat * F.unsqueeze(w, 1), 0)

    def weighted_mean(feat, w):
        return weighted_sum(feat, w) / F.sum(F.unsqueeze(w, 1), 0)

    g1.ndata['x'], g2.ndata['x'] = n1, n2
    g1.ndata['w'], g2.ndata['w'] = w1, w2
    g1.edata['x'] = e1

    # Single, unbatched graph.
    single_checks = [
        (dgl.sum_nodes(g1, 'x'), F.sum(n1, 0)),
        (dgl.sum_nodes(g1, 'x', 'w'), weighted_sum(n1, w1)),
        (dgl.sum_edges(g1, 'x'), F.sum(e1, 0)),
        (dgl.mean_nodes(g1, 'x'), F.mean(n1, 0)),
        (dgl.mean_nodes(g1, 'x', 'w'), weighted_mean(n1, w1)),
        (dgl.mean_edges(g1, 'x'), F.mean(e1, 0)),
        (dgl.max_nodes(g1, 'x'), F.max(n1, 0)),
        (dgl.max_edges(g1, 'x'), F.max(e1, 0)),
    ]
    for got, want in single_checks:
        assert F.allclose(got, want)

    # Batched graphs: one row per graph.
    g = dgl.batch([g1, g2])
    no_edges = F.zeros(5)
    batched_checks = [
        (dgl.sum_nodes(g, 'x'), [F.sum(n1, 0), F.sum(n2, 0)]),
        (dgl.mean_nodes(g, 'x'), [F.mean(n1, 0), F.mean(n2, 0)]),
        (dgl.max_nodes(g, 'x'), [F.max(n1, 0), F.max(n2, 0)]),
        (dgl.sum_nodes(g, 'x', 'w'), [weighted_sum(n1, w1), weighted_sum(n2, w2)]),
        (dgl.mean_nodes(g, 'x', 'w'), [weighted_mean(n1, w1), weighted_mean(n2, w2)]),
        (dgl.sum_edges(g, 'x'), [F.sum(e1, 0), no_edges]),
        (dgl.mean_edges(g, 'x'), [F.mean(e1, 0), no_edges]),
        (dgl.max_edges(g, 'x'), [F.max(e1, 0), no_edges]),
    ]
    for got, rows in batched_checks:
        assert F.allclose(got, F.stack(rows, 0))
def forward(self, graph, node_feature, sub_q_tots):
    """Combine per-group Q values and a state value into Q_tot per graph.

    For each group ``i`` (the i-th entry of ``sub_q_tots``):
    - the features of nodes assigned to group ``i`` are masked in and
      sum-pooled per graph;
    - the pooled features are mapped by ``self.w`` and rectified (``abs`` or
      ``softplus``, chosen by ``self.rectifier``);
    - the result scales that group's Q value.

    ``self.v`` of the per-graph sum of ally node features is then added.

    Diagnostics are printed on every call: group sizes, their ratios, and the
    mean ``normalized_score`` of ally nodes.

    Side effects: scratch fields ``'node_feature'`` and
    ``'masked_node_feature'`` are written to ``graph.ndata`` and removed.
    ``graph.ndata['normalized_score']`` is read.

    Raises
    ------
    RuntimeError
        For an unknown ``self.rectifier``. Only raised once a group is
        processed, i.e. when ``sub_q_tots`` is non-empty.
    """
    device = node_feature.device
    graph.ndata['node_feature'] = node_feature
    group_sizes = []
    q_tot = torch.zeros(graph.batch_size, device=device)

    for group_id, group_q in enumerate(sub_q_tots):
        members = get_filtered_node_index_by_assignment(graph, group_id)
        group_sizes.append(len(members))
        member_mask = torch.zeros(size=(node_feature.shape[0], 1), device=device)
        member_mask[members, :] = 1
        graph.ndata['masked_node_feature'] = graph.ndata['node_feature'] * member_mask
        pooled = dgl.sum_nodes(graph, 'masked_node_feature')
        rectify = {'abs': torch.abs, 'softplus': F.softplus}.get(self.rectifier)
        if rectify is None:
            raise RuntimeError("Not implemented rectifier")
        q_tot = q_tot + rectify(self.w(pooled)).view(-1) * group_q
        graph.ndata.pop('masked_node_feature')

    # State value from ally-only node features.
    graph.ndata.pop('node_feature')
    allies = get_filtered_node_index_by_type(graph, NODE_ALLY)
    ally_only = torch.zeros(size=(graph.number_of_nodes(), node_feature.shape[1]),
                            device=device)
    ally_only[allies, :] = node_feature[allies, :]
    graph.ndata['node_feature'] = ally_only
    q_tot = q_tot + self.v(dgl.sum_nodes(graph, 'node_feature')).view(-1)
    graph.ndata.pop('node_feature')

    # Diagnostics.
    sizes = np.array(group_sizes)
    ratios = sizes / np.sum(sizes)
    print("Num elements in groups {}".format(sizes))
    print("Num elements ratio {}".format(ratios))
    ally_scores = graph.ndata['normalized_score'][allies]
    print("Average normalized scores {}".format(ally_scores.mean(0)))
    return q_tot
def _take_action(self, action):
    """Apply ``action`` to the batched environment state and score the result.

    State encoding used here: ``self.x`` value 2 means undecided; 1 and 0 are
    decided values.

    Only currently undecided entries of ``self.x`` take their value from
    ``action``. Then:
    - Clash: a node with x == 1 that has an in-neighbour with x == 1 is reset
      to undecided.
    - Cleanup: a still-undecided node with such an in-neighbour is set to 0.
    - Timeout: once ``self.t`` reaches ``self.max_epi_t``, the remaining
      undecided nodes are set to 0.

    NOTE(review): ``x1_deg`` is computed before clashing nodes are reset, so a
    node whose only 1-neighbour was reset in this step is still set to 0.

    Returns
    -------
    reward : Tensor
        ``(next_sol - self.sol)``, plus an optional Hamming bonus (only when
        ``hamming_reward_coef > 0`` and ``num_samples == 2``), divided by
        ``self.max_num_nodes``.
    next_sol : Tensor
        Per-graph count of nodes with x == 1.
    done : Tensor
        Result of ``self._check_done()``.
    """
    undecided = self.x == 2
    self.x[undecided] = action[undecided]
    self.t += 1

    # Count, for every node, the in-neighbours that chose 1.
    x1 = (self.x == 1)
    self.g = self.g.to(self.device)
    self.g.ndata['h'] = x1.float()
    # ``copy_src`` is deprecated in DGL; ``copy_u`` is the equivalent built-in.
    self.g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    x1_deg = self.g.ndata.pop('h')

    # Forgive clashing: conflicting 1-nodes go back to undecided.
    clashed = x1 & (x1_deg > 0)
    self.x[clashed] = 2
    x1_deg[clashed] = 0

    # Graph clean-up: undecided nodes next to a 1-node become 0.
    still_undecided = (self.x == 2)
    self.x[still_undecided & (x1_deg > 0)] = 0

    # Fill timeout with zeros.
    still_undecided = (self.x == 2)
    timeout = (self.t == self.max_epi_t)
    self.x[still_undecided & timeout] = 0

    done = self._check_done()
    self.epi_t[~done] += 1

    # Reward and solution value.
    self.g.ndata['h'] = (self.x == 1).float()
    next_sol = dgl.sum_nodes(self.g, 'h')
    self.g.ndata.pop('h')
    reward = (next_sol - self.sol)

    if self.hamming_reward_coef > 0.0 and self.num_samples == 2:
        # Assumes self.x is [#nodes x 2] (two samples) -- split along dim 1.
        xl, xr = self.x.split(1, dim=1)
        undecidedl, undecidedr = undecided.split(1, dim=1)
        hamming_d = torch.abs(xl.float() - xr.float())
        hamming_d[(xl == 2) | (xr == 2)] = 0.0
        hamming_d[~undecidedl & ~undecidedr] = 0.0
        self.g.ndata['h'] = hamming_d
        hamming_reward = dgl.sum_nodes(self.g, 'h').expand_as(reward)
        self.g.ndata.pop('h')
        reward += self.hamming_reward_coef * hamming_reward

    reward /= self.max_num_nodes
    return reward, next_sol, done
def forward(self, g, pos):
    """Pool node features ``g.ndata['h']`` by node role and concatenate.

    ``pos`` assigns a role code to each node: 0, 1 or 2. Field names suggest
    grandparent / parent / sibling -- unconfirmed.

    - Roles 0 and 2: weighted sum of ``h`` over that role, divided by the
      graph's total node count.
    - Role 1: weighted mean, i.e. divided by the number of role-1 nodes.

    The indicator weights are written to ``g.ndata`` as ``'a_gp'``, ``'a_p'``
    and ``'a_sib'``, and remain there after the call.

    Returns
    -------
    Tensor
        The three per-graph embeddings concatenated along dim 1.
    """
    node_counts = torch.tensor(g.batch_num_nodes).float().unsqueeze(1).to(pos.device)
    roles = (('a_gp', 0, True), ('a_p', 1, False), ('a_sib', 2, True))
    pieces = []
    for field, code, divide_by_size in roles:
        g.ndata[field] = (pos == code).float()
        if divide_by_size:
            pieces.append(dgl.sum_nodes(g, 'h', field) / node_counts)
        else:
            pieces.append(dgl.mean_nodes(g, 'h', field))
    return torch.cat(pieces, 1)
def test_sum_case1(idtype):
    """Plain and weighted ``dgl.sum_nodes`` on a single and a batched graph."""
    # NOTE: If you want to update this test case, remember to update the docstring
    # example too!!!
    ctx = F.ctx()
    g1 = dgl.graph(([0, 1], [1, 0]), idtype=idtype, device=ctx)
    g1.ndata['h'] = F.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=ctx)
    g2.ndata['h'] = F.tensor([1., 2., 3.])
    bg = dgl.batch([g1, g2])
    bg.ndata['w'] = F.tensor([.1, .2, .1, .5, .2])

    expectations = [
        # g1 alone: 1 + 2
        ([3.], dgl.sum_nodes(g1, 'h')),
        # per graph: 1 + 2 | 1 + 2 + 3
        ([3., 6.], dgl.sum_nodes(bg, 'h')),
        # weighted: 1*.1 + 2*.2 | 1*.1 + 2*.5 + 3*.2
        ([.5, 1.7], dgl.sum_nodes(bg, 'h', 'w')),
    ]
    for expected, actual in expectations:
        assert F.allclose(F.tensor(expected), actual)
def test_simple_readout():
    """Check sum/mean node and edge readouts, plain and weighted, on a single
    graph and on a batch whose second graph has no edges."""
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # g2 deliberately gets no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random draws, in order: node features, edge features, node weights.
    n1, n2, e1 = th.randn(3, 5), th.randn(4, 5), th.randn(3, 5)
    w1, w2 = th.randn(3), th.randn(4)

    def weighted_sum(feat, w):
        return (feat * w[:, None]).sum(0)

    def weighted_mean(feat, w):
        return weighted_sum(feat, w) / w[:, None].sum(0)

    g1.ndata['x'], g2.ndata['x'] = n1, n2
    g1.ndata['w'], g2.ndata['w'] = w1, w2
    g1.edata['x'] = e1

    single_checks = [
        (dgl.sum_nodes(g1, 'x'), n1.sum(0)),
        (dgl.sum_nodes(g1, 'x', 'w'), weighted_sum(n1, w1)),
        (dgl.sum_edges(g1, 'x'), e1.sum(0)),
        (dgl.mean_nodes(g1, 'x'), n1.mean(0)),
        (dgl.mean_nodes(g1, 'x', 'w'), weighted_mean(n1, w1)),
        (dgl.mean_edges(g1, 'x'), e1.mean(0)),
    ]
    for got, want in single_checks:
        assert U.allclose(got, want)

    # Batched: one row per graph; the edgeless g2 reads out zeros.
    g = dgl.batch([g1, g2])
    no_edges = th.zeros(5)
    batched_checks = [
        (dgl.sum_nodes(g, 'x'), [n1.sum(0), n2.sum(0)]),
        (dgl.mean_nodes(g, 'x'), [n1.mean(0), n2.mean(0)]),
        (dgl.sum_nodes(g, 'x', 'w'), [weighted_sum(n1, w1), weighted_sum(n2, w2)]),
        (dgl.mean_nodes(g, 'x', 'w'), [weighted_mean(n1, w1), weighted_mean(n2, w2)]),
        (dgl.sum_edges(g, 'x'), [e1.sum(0), no_edges]),
        (dgl.mean_edges(g, 'x'), [e1.mean(0), no_edges]),
    ]
    for got, rows in batched_checks:
        assert U.allclose(got, th.stack(rows, 0))
def euclidean_matrix(graphs, dims, readout='sum'):
    '''Return pairwise Euclidean distance matrices between graph readouts.

    Steps:
    1. Batch the graphs.
    2. Pool each graph's node features ``ndata['h']`` with ``readout``.
    3. Split the pooled vectors into consecutive column blocks whose widths
       are given by ``dims``.
    4. For each block, compute the matrix between all graphs with
       ``dist_matrix``.

    graphs : list of dgl graphs
        Each must carry node features under ``'h'``.
    dims : sequence of int (list or tuple)
        Graph features are the concatenation of features obtained from all
        iterations; this gives each iteration's feature width.
    readout : {'sum', 'mean'}
        Node-to-graph pooling.

    Returns a list with one numpy array per entry of ``dims``.
    Raises ValueError for any other ``readout``.
    '''
    if readout == 'sum':
        pool = dgl.sum_nodes
    elif readout == 'mean':
        pool = dgl.mean_nodes
    else:
        # Name this function in the message (it used to say gram_matrix).
        raise ValueError('Readout for euclidean_matrix shall be either "mean" or "sum"')
    batched = dgl.batch(graphs)
    # Column offsets of each iteration's block; unpacking accepts any sequence.
    bounds = np.cumsum([0, *dims])
    distances = []
    # Results are returned as numpy arrays, so no autograd graph is needed.
    with torch.no_grad():
        graph_reprs = pool(batched, 'h')
        for start, end in zip(bounds, bounds[1:]):
            features = graph_reprs[:, start:end]
            distances.append(dist_matrix(features, features).cpu().numpy())
    return distances
def gram_matrix(graphs, dims, dist_fn, readout='sum'):
    '''Compute gram matrices on a batch of graphs, one per feature block.

    Steps:
    1. Batch the graphs.
    2. Pool each graph's node features ``ndata['h']`` with ``readout``.
    3. Split the pooled vectors into consecutive column blocks whose widths
       are given by ``dims``.
    4. For each block, evaluate ``dist_fn(features, features)``.

    graphs : list of dgl graphs
        Each must carry node features under ``'h'``.
    dims : sequence of int (list or tuple)
        Graph features are the concatenation of features obtained from all
        iterations; this gives each iteration's feature width.
    dist_fn : callable
        Takes two feature tensors and returns a tensor.
    readout : {'sum', 'mean'}
        Node-to-graph pooling.

    Returns a list of gram matrices as numpy arrays, one per entry of ``dims``.
    Raises ValueError for any other ``readout``.
    '''
    if readout == 'sum':
        pool = dgl.sum_nodes
    elif readout == 'mean':
        pool = dgl.mean_nodes
    else:
        raise ValueError('Readout for gram_matrix shall be either "mean" or "sum"')
    batched = dgl.batch(graphs)
    bounds = np.cumsum([0, *dims])
    matrices = []
    # Results are returned as numpy arrays, so no autograd graph is needed.
    with torch.no_grad():
        graph_reprs = pool(batched, 'h')
        for start, end in zip(bounds, bounds[1:]):
            features = graph_reprs[:, start:end]
            # Local name must not shadow this function (it used to).
            block_matrix = dist_fn(features, features)
            matrices.append(block_matrix.cpu().numpy())
    return matrices
def forward(self, graph, edge_feat, node_feat, g_repr):
    """Run one graph-network step: edge update, node update, global update.

    Steps:
    1. Attach ``edge_feat`` / ``node_feat`` to ``graph``.
    2. Update edges with ``compute_edge_repr``.
    3. Pass messages and update nodes with ``compute_node_repr``.
    4. Sum the updated edge and node features per graph; these sums feed
       ``compute_u_repr``.

    Side effect: ALL node and edge data fields of ``graph`` are removed before
    returning, including any that existed before the call.

    Returns
    -------
    tuple
        ``(edge features, node features, compute_u_repr(...) result)``.
    """
    def edge_update(edges):
        return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

    def node_update(nodes):
        return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

    graph.edata['edge_feat'] = edge_feat
    graph.ndata['node_feat'] = node_feat
    graph.apply_edges(edge_update)
    graph.update_all(self.graph_message_func, self.graph_reduce_func, node_update)

    edge_sum = dgl.sum_edges(graph, 'edge_feat')
    node_sum = dgl.sum_nodes(graph, 'node_feat')
    new_edges = graph.edata['edge_feat']
    new_nodes = graph.ndata['node_feat']

    for store in (graph.edata, graph.ndata):
        for key in list(store.keys()):
            store.pop(key)
    return new_edges, new_nodes, self.compute_u_repr(node_sum, edge_sum, g_repr)
def forward(self, g, node_feats):
    r"""Sum-pool node representations into graph representations.

    When ``self.gaussian_expand`` is true:
    - node features first go through ``self.gaussian_histogram``;
    - the pooled result then goes through ``self.to_out``, followed by
      ``self.activation`` (if it is not None).

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs. Not modified (a local scope is used).
    node_feats : float32 tensor of shape (V, node_in_feats)
        Input node features. V for the number of nodes in the batch of graphs.

    Returns
    -------
    g_feats : float32 tensor of shape (G, node_in_feats)
        Output graph representations. G for the number of graphs in the batch.
    """
    if self.gaussian_expand:
        node_feats = self.gaussian_histogram(node_feats)

    with g.local_scope():
        g.ndata['h'] = node_feats
        pooled = dgl.sum_nodes(g, 'h')

    if not self.gaussian_expand:
        return pooled
    pooled = self.to_out(pooled)
    return pooled if self.activation is None else self.activation(pooled)
def forward(self, bg, feat):
    """Set2Set-style readout: iterative LSTM-driven attention pooling.

    Each of ``self.processing_steps`` rounds does the following:
    1. ``self.lstm`` maps the previous ``q_star`` to a query ``q`` per graph.
    2. Every node scores the dot product of its features with its graph's
       query.
    3. Scores are softmax-normalised within each graph.
    4. The attention-weighted sum of node features ``r`` is concatenated with
       ``q`` to form the next ``q_star``.

    Parameters
    ----------
    bg : batched DGLGraph
        ``bg.ndata['w']`` is overwritten with the attention weights.
    feat : str
        Name of the node-feature field to pool.

    Returns
    -------
    Tensor
        The final ``q_star`` (``q`` concatenated with ``r``), one row per
        graph.
    """
    batch_size = bg.batch_size
    x = bg.ndata[feat]
    batch_num_nodes = bg.batch_num_nodes
    if callable(batch_num_nodes):
        # DGL >= 0.5 exposes batch_num_nodes() as a method returning a tensor.
        batch_num_nodes = batch_num_nodes().tolist()
    # Graph id of every node, e.g. [2, 3] -> [0, 0, 1, 1, 1]. Built with one
    # vectorised call on x's device, replacing a quadratic torch.cat loop
    # that built the index on CPU.
    counts = torch.as_tensor(batch_num_nodes, dtype=torch.long, device=x.device)
    batch = torch.repeat_interleave(torch.arange(len(counts), device=x.device), counts)

    h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
         x.new_zeros((self.num_layers, batch_size, self.in_channels)))
    q_star = x.new_zeros(batch_size, self.out_channels)
    for _ in range(self.processing_steps):
        q, h = self.lstm(q_star.unsqueeze(0), h)
        q = q.view(batch_size, self.in_channels)
        scores = (x * q[batch]).sum(dim=-1, keepdim=True)
        # Softmax over nodes, separately within each graph.
        attn = torch.cat([softmax(chunk, dim=0)
                          for chunk in torch.split(scores, batch_num_nodes)], dim=0)
        bg.ndata['w'] = attn
        r = dgl.sum_nodes(bg, feat, 'w')
        q_star = torch.cat([q, r], dim=-1)
    return q_star
def forward(self, g, h, e, snorm_n, snorm_e, mlp=True, head=False, return_graph=False):
    """Encode a graph and return one of several outputs.

    Node features are embedded, dropped out, and passed through each layer in
    ``self.layers`` with ``snorm_n``. The result is stored in ``g.ndata['h']``.
    What is returned depends on the flags, checked in this order:
    - ``return_graph``: the graph ``g`` itself.
    - ``mlp`` (default): ``self.MLP_layer`` of the pooled representation.
    - ``head``: ``self.projection_head`` of the pooled representation.
    - otherwise: the pooled representation itself.

    Readout is sum/max/mean per ``self.readout``; any other value means mean.
    ``e`` and ``snorm_e`` are unused.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    for conv in self.layers:
        h = conv(g, h, snorm_n)
    g.ndata['h'] = h
    if return_graph:
        return g

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    hg = pool(g, 'h')
    if mlp:
        return self.MLP_layer(hg)
    return self.projection_head(hg) if head else hg
def forward(self, g, node_feats):
    """Computes graph representations out of node features.

    Node features are projected with ``in_project``, passed through
    ``activation`` (if any), projected with ``out_project``, and finally pooled
    per graph according to ``self.mode``.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs. Not modified (a local scope is used).
    node_feats : float32 tensor of shape (V, node_feats)
        Input node features, V for the number of nodes.

    Returns
    -------
    graph_feats : float32 tensor of shape (G, graph_feats)
        Graph representations computed. G for the number of graphs.

    Raises
    ------
    ValueError
        If ``self.mode`` is not 'max', 'mean' or 'sum'. Previously this
        surfaced as an obscure UnboundLocalError on ``graph_feats``.
    """
    readouts = {'max': dgl.max_nodes, 'mean': dgl.mean_nodes, 'sum': dgl.sum_nodes}
    readout = readouts.get(self.mode)
    if readout is None:
        raise ValueError(
            "Expected mode to be 'max', 'mean' or 'sum', got {!r}".format(self.mode))

    node_feats = self.in_project(node_feats)
    if self.activation is not None:
        node_feats = self.activation(node_feats)
    node_feats = self.out_project(node_feats)

    with g.local_scope():
        g.ndata['h'] = node_feats
        graph_feats = readout(g, 'h')
    return graph_feats
def forward(self, g, h, e):
    """Encode a reduced graph and predict with ``self.MLP_layer``.

    Each layer receives the edge input ``e`` (reduced-graph variant; for
    original graphs the call would be ``conv(g, h)``). The readout is weighted
    by the node field ``g.ndata['weight']`` (for original graphs, drop
    ``weight=``). ``self.readout`` selects sum/max/mean; any other value means
    mean. ``g.ndata['h']`` is overwritten.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    for conv in self.layers:
        h = conv(g, h, e)
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    hg = pool(g, feat='h', weight='weight')
    return self.MLP_layer(hg)
def forward(self, g, h, e, pos_enc=None):
    """Encode with degree-dependent pseudo-coordinates and predict per graph.

    When ``self.pos_enc`` is set, the input embedding comes from ``pos_enc``
    instead of ``h``. ``g.ndata['deg']`` is set to the in-degrees, and
    ``self.compute_pseudo`` (applied to all edges) is expected to produce
    ``g.edata['pseudo']``. Each layer ``i`` receives that tensor projected by
    ``self.pseudo_proj[i]``.

    Readout: sum/max/mean per ``self.readout``, defaulting to mean.
    ``e`` is unused.
    """
    h = self.embedding_pos_enc(pos_enc) if self.pos_enc else self.embedding_h(h)

    g.ndata['deg'] = g.in_degrees()
    g.apply_edges(self.compute_pseudo)
    pseudo = g.edata['pseudo'].to(self.device).float()

    for i, layer in enumerate(self.layers):
        h = layer(g, h, self.pseudo_proj[i](pseudo))
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g, h, e, snorm_n, snorm_e):
    """Graph conv stack with optional virtual-node layers, then MLP readout.

    Edge features are embedded only when ``self.edge_feat`` is set. After conv
    layer ``i``, if ``self.virtual_node_layers`` is not None and has an entry
    ``i``, that layer's ``forward`` updates both the virtual-node state (which
    starts as the scalar 0) and the node features.

    Readout: sum/max/mean per ``self.readout``, defaulting to mean.
    ``g.ndata['h']`` is overwritten.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.edge_feat:
        e = self.embedding_e(e)

    vn_h = 0
    for i, conv in enumerate(self.layers):
        h = conv(g, h, e, snorm_n)
        vn_layers = self.virtual_node_layers
        if vn_layers is not None and i < len(vn_layers):
            vn_h, h = vn_layers[i].forward(g, h, vn_h)

    g.ndata['h'] = h
    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g):
    """Multi-level node encoder with a per-node output summed per graph.

    Steps:
    1. Layers write node levels ``node_0`` .. ``node_{n_conv}`` into
       ``g.ndata``.
    2. The levels are concatenated into ``g.ndata['node']``.
    3. Two dense layers produce a per-node result, optionally shifted by
       ``g.ndata['e0']`` (when ``atom_ref`` is set) and de-normalised (when
       ``norm`` is set).
    4. The result is stored in ``g.ndata['res']`` and summed per graph.
    """
    self.embedding_layer(g, "node_0")
    if self.atom_ref is not None:
        self.e0(g, "e0")
    self.rbf_layer(g)
    self.edge_embedding_layer(g)
    for level in range(1, self.n_conv + 1):
        self.conv_layers[level - 1](g, level)

    # Concatenate the multi-level representations along the feature axis.
    levels = [g.ndata["node_%d" % level] for level in range(self.n_conv + 1)]
    g.ndata["node"] = th.cat(levels, 1)

    hidden = self.activation(self.node_dense_layer1(g.ndata["node"]))
    out = self.node_dense_layer2(hidden)
    if self.atom_ref is not None:
        out = out + g.ndata["e0"]
    if self.norm:
        out = out * self.std_per_node + self.mean_per_node
    g.ndata["res"] = out
    return dgl.sum_nodes(g, "res")
def forward(self, g, node_feats, g_feats, get_node_weight=False):
    """Attention-pool nodes and update graph features with a GRU.

    Steps:
    1. Each node's logit is ``compute_logits`` of its graph's
       ReLU(``g_feats``) concatenated with its own features.
    2. Logits are softmax-normalised per graph.
    3. The weights pool ``project_nodes(node_feats)``; ELU is applied.
    4. The pooled context updates ``g_feats`` through ``self.gru``.

    Parameters
    ----------
    g : DGLGraph or BatchedDGLGraph
        Constructed DGLGraphs. Not modified (a local scope is used).
    node_feats : float32 tensor of shape (V, N1)
        Input node features. V for the number of nodes and N1 for the feature size.
    g_feats : float32 tensor of shape (G, N2)
        Input graph features. G for the number of graphs and N2 for the feature size.
    get_node_weight : bool
        Whether to also return the node attention weights.

    Returns
    -------
    float32 tensor of shape (G, N2)
        Updated graph features.
    float32 tensor of shape (V, 1)
        The node weights used in the readout; only when ``get_node_weight``.
    """
    with g.local_scope():
        graph_ctx = dgl.broadcast_nodes(g, F.relu(g_feats))
        g.ndata['z'] = self.compute_logits(torch.cat([graph_ctx, node_feats], dim=1))
        g.ndata['a'] = dgl.softmax_nodes(g, 'z')
        g.ndata['hv'] = self.project_nodes(node_feats)
        context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
        updated = self.gru(context, g_feats)
        return (updated, g.ndata['a']) if get_node_weight else updated
def forward(self, graph, edge_feat, node_feat, g_repr, edge_hidden, node_hidden, graph_hidden):
    """Recurrent graph-network step: edge, node and global updates with state.

    Inputs written to the graph:
    - ``edge_feat`` and ``node_feat``;
    - slice ``[0]`` of the two hidden-state entries of ``edge_hidden`` /
      ``node_hidden``, as fields ``'hidden1'`` and ``'hidden2'`` (presumably
      read by ``compute_edge_repr`` / ``compute_node_repr`` -- TODO confirm).

    Edges and nodes are then updated. The per-graph sums of the updated
    features feed ``compute_u_repr`` together with ``graph_hidden``.

    Side effect: ALL node and edge data fields of ``graph`` are removed before
    returning.

    Returns
    -------
    tuple
        ``(edge feats, (edge hidden1, edge hidden2), node feats,
        (node hidden1, node hidden2), u_out, u_hidden)``. Each hidden tensor
        has a new leading dim of size 1.
    """
    graph.edata['edge_feat'] = edge_feat
    graph.ndata['node_feat'] = node_feat
    for slot in (0, 1):
        key = 'hidden%d' % (slot + 1)
        graph.edata[key] = edge_hidden[slot][0]
        graph.ndata[key] = node_hidden[slot][0]

    def edge_update(edges):
        return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

    def node_update(nodes):
        return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

    graph.apply_edges(edge_update)
    graph.update_all(self.graph_message_func, self.graph_reduce_func, node_update)

    edge_sum = dgl.sum_edges(graph, 'edge_feat')
    node_sum = dgl.sum_nodes(graph, 'node_feat')
    u_out, u_hidden = self.compute_u_repr(node_sum, edge_sum, g_repr, graph_hidden)

    e_feat, n_feat = graph.edata['edge_feat'], graph.ndata['node_feat']
    hidden_keys = ('hidden1', 'hidden2')
    h_e = tuple(graph.edata[k].unsqueeze(0) for k in hidden_keys)
    h_n = tuple(graph.ndata[k].unsqueeze(0) for k in hidden_keys)

    for store in (graph.edata, graph.ndata):
        for key in list(store.keys()):
            store.pop(key)
    return e_feat, h_e, n_feat, h_n, u_out, u_hidden
def forward(self, g):
    """Score graphs by summing per-layer predictions over all GIN layers.

    Steps:
    1. Start from ``g.ndata['attr']``, moved to ``self.device``.
    2. Apply ``num_layers - 1`` GIN layers.
    3. Pool the input and every layer's output per graph
       (sum/mean/max per ``self.graph_pooling_type``).
    4. Map each pooled vector with its own ``linears_prediction`` entry and
       apply dropout.
    5. Sum the results.

    ``g.ndata['h']`` is overwritten.

    Raises
    ------
    NotImplementedError
        For an unknown pooling type.
    """
    h = g.ndata['attr'].to(self.device)
    hidden_rep = [h]  # input features plus the output of every GIN layer
    for layer_idx in range(self.num_layers - 1):
        h = self.ginlayers[layer_idx](g, h)
        hidden_rep.append(h)

    poolers = {'sum': dgl.sum_nodes, 'mean': dgl.mean_nodes, 'max': dgl.max_nodes}
    score_over_layer = 0
    for layer_idx, rep in enumerate(hidden_rep):
        g.ndata['h'] = rep
        if self.graph_pooling_type not in poolers:
            raise NotImplementedError()
        pooled_h = poolers[self.graph_pooling_type](g, 'h')
        score_over_layer += F.dropout(self.linears_prediction[layer_idx](pooled_h),
                                      self.final_dropout,
                                      training=self.training)
    return score_over_layer
def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):
    """Graph transformer forward pass with optional positional encodings.

    Steps:
    1. Embed node features and apply dropout.
    2. Add the embedded Laplacian and/or WL positional encodings when enabled.
    3. Embed edge features. When ``self.edge_feat`` is false, every edge gets
       the constant feature 1.
    4. Each layer updates ``(h, e)``.
    5. Readout: sum/max/mean per ``self.readout``, defaulting to mean.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    if self.lap_pos_enc:
        h = h + self.embedding_lap_pos_enc(h_lap_pos_enc.float())
    if self.wl_pos_enc:
        h = h + self.embedding_wl_pos_enc(h_wl_pos_enc)

    edge_input = e if self.edge_feat else torch.ones(e.size(0), 1).to(self.device)
    e = self.embedding_e(edge_input)

    for conv in self.layers:
        h, e = conv(g, h, e)
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g):
    """Predict a per-graph value as the sum of per-node outputs.

    Steps:
    1. Reshape ``g.edata['distance']`` to a column and run the embedding,
       optional ``e0``, RBF and ``n_conv`` conv layers on ``g``.
    2. Map ``g.ndata['node']`` through two activated dense layers and the
       regressor.
    3. Optionally add ``e0`` and de-normalise.
    4. Store the result in ``g.ndata['res']`` and sum it per graph.
    """
    g.edata['distance'] = g.edata['distance'].reshape(-1, 1)
    self.embedding_layer(g)
    if self.atom_ref is not None:
        self.e0(g, "e0")
    self.rbf_layer(g)
    for conv_idx in range(self.n_conv):
        conv = self.conv_layers[conv_idx]
        conv(g)

    hidden = self.activation(self.atom_dense_layer1(g.ndata["node"]))
    hidden = self.activation(self.atom_dense_layer2(hidden))
    out = self.regressor(hidden)
    if self.atom_ref is not None:
        out = out + g.ndata["e0"]
    if self.norm:
        out = out * self.std_per_atom + self.mean_per_atom
    g.ndata["res"] = out
    return dgl.sum_nodes(g, "res")
def forward(self, g, h, e, snorm_n, snorm_e):
    """Conv stack where each layer's output is re-joined with the input embedding.

    For every ``i < self.layer_count``:
    ``h = joining_layers[i](h_init + layers[i](g, h, snorm_n))``, where
    ``h_init`` is the embedded, dropped-out input.

    Readout: sum/max/mean per ``self.readout``, defaulting to mean.
    ``e`` and ``snorm_e`` are unused.
    """
    h = self.in_feat_dropout(self.embedding_h(h))
    h_init = h
    for i in range(self.layer_count):
        conv, join = self.layers[i], self.joining_layers[i]
        h = join(h_init + conv(g, h, snorm_n))
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))
def forward(self, g, h, e, snorm_n, snorm_e):
    """Encode with degree-based pseudo-coordinates and predict per graph.

    Every edge (u, v) gets the pseudo-coordinate
    ``(1/sqrt(in_deg(u)+1), 1/sqrt(in_deg(v)+1))``. Each layer ``i`` receives
    it projected by ``self.pseudo_proj[i]``, together with ``snorm_n``.

    Readout: sum/max/mean per ``self.readout``, defaulting to mean.
    ``e`` and ``snorm_e`` are unused.
    """
    h = self.embedding_h(h)

    # The +1 (a virtual self-loop) avoids division by zero for in-degree 0.
    # Vectorised replacement for a per-edge Python loop that called the
    # legacy g.in_degree() 2*E times. It also gives shape (0, 2) rather than
    # (0,) for edgeless graphs. Computing in float64 and casting to float32
    # matches the previous np.sqrt -> torch.Tensor values.
    us, vs = g.edges()
    inv_sqrt_deg = 1.0 / torch.sqrt(g.in_degrees().double() + 1)
    pseudo = torch.stack([inv_sqrt_deg[us.long()], inv_sqrt_deg[vs.long()]], dim=1)
    pseudo = pseudo.float().to(self.device)

    for i in range(len(self.layers)):
        h = self.layers[i](g, h, self.pseudo_proj[i](pseudo), snorm_n)
    g.ndata['h'] = h

    pool = {"sum": dgl.sum_nodes,
            "max": dgl.max_nodes,
            "mean": dgl.mean_nodes}.get(self.readout, dgl.mean_nodes)
    return self.MLP_layer(pool(g, 'h'))