def forward(self, input_hidden, graphs: dgl.DGLGraph, batch_num_nodes=None):
    """Refine node states with query-conditioned attention and pad them per graph.

    Parameters
    ----------
    input_hidden : tensor
        Query hidden state. It is projected with ``self.input_proj`` and then
        broadcast to the nodes and edges of ``graphs`` (presumably one row per
        graph in the batch -- TODO confirm).
    graphs : dgl.DGLGraph
        Batched graph with initial node states in ``ndata['F_n']`` and edge
        states in ``edata['F_e']``. NOTE: intermediate results ('h_t', 'F_n_t',
        'F_e_t', 's_n', ...) are written into ``graphs.ndata`` / ``graphs.edata``.
    batch_num_nodes : list of int, optional
        Node count of every graph; defaults to ``graphs.batch_num_nodes``.

    Returns
    -------
    io : tensor of shape (B, max_num_nodes, D)
        Node states split per graph and zero-padded to a common length.
    io_mask : bool tensor of shape (B, max_num_nodes)
        True at positions holding a real node, False at padding positions.
    """
    if batch_num_nodes is None:
        # NOTE(review): attribute-style batch_num_nodes is the older DGL API; confirm version
        b_num_nodes = graphs.batch_num_nodes
    else:
        b_num_nodes = batch_num_nodes
    h_t = self.input_proj(input_hidden)
    # when there are no edges in the graph, there is nothing to do
    if graphs.number_of_edges() > 0:
        # give all the nodes and edges information about the current query hidden state
        graphs.ndata['h_t'] = dgl.broadcast_nodes(graphs, h_t)
        graphs.edata['h_t'] = dgl.broadcast_edges(graphs, h_t)
        # create a copy of the node and edge states which will be updated for K iterations
        graphs.ndata['F_n_t'] = graphs.ndata['F_n']
        graphs.edata['F_e_t'] = graphs.edata['F_e']
        for _ in range(self.k_update_steps):
            graphs.ndata['s_n'] = self.object_score(
                torch.cat([graphs.ndata['h_t'], graphs.ndata['F_n_t']], dim=-1))
            # NOTE(review): DGLGraph.send/recv are not available in newer DGL releases;
            # confirm the pinned DGL version supports them.
            graphs.send(message_func=self.io_attention_send)
            graphs.recv(reduce_func=self.io_attention_reduce)
            graphs.ndata['F_n_t'] = graphs.ndata['F_i_tplus1']
            if self.update_relations:
                graphs.edata['F_e_t'] = graphs.edata['F_e_tplus1']
        node_states = graphs.ndata['F_n_t']
    else:
        node_states = graphs.ndata['F_n']
    per_graph = torch.split(node_states, split_size_or_sections=b_num_nodes)
    io = pad_sequence(per_graph, batch_first=True)
    # Build the mask from the true node counts. Testing ``io.sum(-1) != 0`` would
    # also mask out real nodes whose features happen to sum to zero.
    lengths = torch.tensor([chunk.size(0) for chunk in per_graph], device=io.device)
    positions = torch.arange(io.size(1), device=io.device)
    io_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return io, io_mask
def forward(self, g, node_feats, g_feats, get_node_weight=False):
    """Run one attention readout step and update the graph features with a GRU.

    Parameters
    ----------
    g : DGLGraph or BatchedDGLGraph
        Constructed DGLGraphs.
    node_feats : float32 tensor of shape (V, N1)
        Input node features. V for the number of nodes and N1 for the feature size.
    g_feats : float32 tensor of shape (G, N2)
        Input graph features. G for the number of graphs and N2 for the feature size.
    get_node_weight : bool
        Whether to also return the per-node readout weights.

    Returns
    -------
    float32 tensor of shape (G, N2)
        Updated graph features.
    float32 tensor of shape (V, 1)
        The weights of nodes in readout; only returned when ``get_node_weight``
        is True.
    """
    with g.local_scope():
        # Score every node against the (ReLU-activated) feature of its own graph.
        graph_feat_per_node = dgl.broadcast_nodes(g, F.relu(g_feats))
        g.ndata['z'] = self.compute_logits(
            torch.cat([graph_feat_per_node, node_feats], dim=1))
        # Normalise the scores within each graph.
        g.ndata['a'] = dgl.softmax_nodes(g, 'z')
        g.ndata['hv'] = self.project_nodes(node_feats)
        # Attention-weighted sum of projected node features, per graph.
        context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))
        updated = self.gru(context, g_feats)
        if not get_node_weight:
            return updated
        return updated, g.ndata['a']
def forward(self, graph, node_feat, edge_feat):
    """Encode a batched graph, optionally with a virtual node, and predict per graph.

    Node features are encoded, passed through ``self.num_layers`` edge-aware
    convolutions, pooled per graph with ``self.pool`` and fed to ``self.pred``.
    When ``self.virtual_node`` is set, a per-graph virtual-node embedding is
    added to every node before each layer and updated after every non-final layer.
    """
    use_virtual = self.virtual_node
    if use_virtual:
        # one virtual-node embedding row per graph in the batch
        virtual_emb = self.virtual_emb.weight.expand(graph.batch_size, -1)
    hn = self.node_encoder(node_feat)
    final_layer = self.num_layers - 1
    for layer in range(self.num_layers):
        is_final = layer == final_layer
        if use_virtual:
            # virtual node -> graph nodes
            hn = hn + dgl.broadcast_nodes(graph, virtual_emb)
        hn = self.conv_layers[layer](graph, hn, self.edge_encoders[layer](edge_feat))
        if not is_final:
            hn = F.relu(hn)
        hn = self.dropout(hn)
        if use_virtual and not is_final:
            # graph nodes -> virtual node
            pooled = self.virtual_pool(graph, hn) + virtual_emb
            virtual_emb = self.dropout(F.relu(self.mlp_virtual[layer](pooled)))
    return self.pred(self.pool(graph, hn))
def forward(self, inputs, extra_inputs=None):
    """Score every candidate item for each (user, session graph) pair.

    ``inputs`` is unpacked as ``(uid, g)``; ``g.ndata['iid']`` gives the item id
    of every node. Item and user embeddings come from the parent class's
    ``forward(extra_inputs)``. Returns ``fc_sr([session_repr, user_emb])``
    multiplied by the transposed embeddings of ``self.item_indices``.
    """
    kg_emb = super().forward(extra_inputs)
    uid, g = inputs
    item_emb = kg_emb['item'][g.ndata['iid']]  # one row per node
    user_emb = kg_emb['user'][uid]
    # Each node's item feature plus the user feature of the graph it belongs to.
    node_input = self.fc_i(item_emb) + dgl.broadcast_nodes(g, self.fc_u(user_emb))
    session_repr = self.PSE_layer(g, node_input)
    combined = th.cat([session_repr, user_emb], dim=1)
    projected = self.fc_sr(combined)
    return projected @ self.item_embedding(self.item_indices).t()
def forward(self, g, feat, last_nodes):
    """Attention readout that uses each graph's last node as the query.

    ``feat`` is optionally batch-normalised. Node scores are
    ``attn_e(sigmoid(fc_u(feat) + fc_v(feat[last_nodes])))``, with the query
    broadcast to the nodes of its graph. The scores are softmax-normalised per
    graph and used to sum the node features; the sum is projected by ``fc_out``.
    """
    with g.local_scope():
        if self.batch_norm is not None:
            feat = self.batch_norm(feat)
        keys = self.fc_u(feat)
        query = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
        g.ndata['e'] = self.attn_e(th.sigmoid(keys + query))
        weights = dgl.softmax_nodes(g, 'e')
        g.ndata['w'] = feat * weights
        return self.fc_out(dgl.sum_nodes(g, 'w'))
def forward(self, g, x, edge_attr):
    """Compute node representations with a virtual-node-augmented GNN.

    Parameters
    ----------
    g : DGLGraph
        Batched input graph.
    x : tensor
        Node input features, encoded by ``self.atom_encoder``.
    edge_attr : tensor
        Edge features passed to every convolution layer.

    Returns
    -------
    tensor
        Node representations. With ``self.JK == "last"`` this is the last
        layer's output. With ``self.JK == "sum"`` it is the sum of the input
        embedding and every layer's output.

    Raises
    ------
    ValueError
        If ``self.JK`` is neither "last" nor "sum".
    """
    ### virtual node embeddings for graphs
    # Embedding lookups need integer indices, so build them as int64 rather than
    # borrowing x's dtype (which fails when x holds float features).
    # NOTE(review): assumes self.virtualnode_embedding is an nn.Embedding.
    virtual_ids = torch.zeros(g.batch_size, dtype=torch.long, device=x.device)
    virtualnode_embedding = self.virtualnode_embedding(virtual_ids)
    h_list = [self.atom_encoder(x)]
    # graph index of every node, used to select its graph's virtual-node row
    batch_id = dgl.broadcast_nodes(g, torch.arange(g.batch_size).to(x.device))
    for layer in range(self.num_layers):
        ### add message from virtual nodes to graph nodes
        h_list[layer] = h_list[layer] + virtualnode_embedding[batch_id]
        ### Message passing among graph nodes
        h = self.convs[layer](g, h_list[layer], edge_attr)
        h = self.batch_norms[layer](h)
        if layer == self.num_layers - 1:
            # remove relu for the last layer
            h = F.dropout(h, self.drop_ratio, training=self.training)
        else:
            h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        if self.residual:
            h = h + h_list[layer]
        h_list.append(h)
        ### update the virtual nodes
        if layer < self.num_layers - 1:
            ### add message from graph nodes to virtual nodes
            virtualnode_embedding_temp = self.pool(g, h_list[layer]) + virtualnode_embedding
            ### transform virtual nodes using MLP
            virtualnode_embedding_temp = self.mlp_virtualnode_list[layer](
                virtualnode_embedding_temp)
            if self.residual:
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    virtualnode_embedding_temp, self.drop_ratio, training=self.training)
            else:
                virtualnode_embedding = F.dropout(
                    virtualnode_embedding_temp, self.drop_ratio, training=self.training)
    ### Different implementations of Jk-concat
    if self.JK == "last":
        node_representation = h_list[-1]
    elif self.JK == "sum":
        # h_list holds num_layers + 1 entries (input embedding + one per layer);
        # summing only the first num_layers would silently drop the last layer.
        node_representation = 0
        for layer in range(self.num_layers + 1):
            node_representation += h_list[layer]
    else:
        raise ValueError("Unsupported JK mode: {!r}".format(self.JK))
    return node_representation
def collate(samples):
    '''Collate (graph, target) pairs into a batched graph and a target tensor.

    Every node of the batched graph is tagged with ``ndata['graph_id']``, the
    position of the sample it came from.
    '''
    graphs, targets = [list(column) for column in zip(*samples)]
    batched_graph = dgl.batch(graphs)
    batched_targets = th.Tensor(targets)
    # per-node index of the source graph inside the batch
    batched_graph.ndata['graph_id'] = dgl.broadcast_nodes(
        batched_graph, th.arange(len(graphs)))
    return batched_graph, batched_targets
def test_broadcast_nodes():
    # Case 1: a single graph -- every node receives a copy of the feature.
    g0 = dgl.DGLGraph(nx.path_graph(10))
    feat0 = F.randn((40, ))
    expected = F.stack([feat0] * g0.number_of_nodes(), 0)
    assert F.allclose(dgl.broadcast_nodes(g0, feat0), expected)

    # Case 2: a batch (including an empty graph) -- each node receives the
    # feature row of the graph it belongs to.
    graphs = [
        g0,
        dgl.DGLGraph(nx.path_graph(3)),
        dgl.DGLGraph(),
        dgl.DGLGraph(nx.path_graph(12)),
    ]
    bg = dgl.batch(graphs)
    feats = [feat0] + [F.randn((40, )) for _ in range(3)]
    expected = F.stack(
        [feat for graph, feat in zip(graphs, feats)
         for _ in range(graph.number_of_nodes())],
        0)
    assert F.allclose(dgl.broadcast_nodes(bg, F.stack(feats, 0)), expected)
def forward(self, g, feat, last_nodes):
    """Segment-softmax attention readout keyed on each graph's last node.

    Scores are ``fc_e(sigmoid(fc_u(feat) + fc_v(feat[last_nodes])))``, with the
    query broadcast to the nodes of its graph. They are softmax-normalised
    within each graph's node segment and used to sum the (normalised, dropped
    out) node features. The result then goes through ``fc_out`` and
    ``activation`` when those are set.
    """
    if self.batch_norm is not None:
        feat = self.batch_norm(feat)
    feat = self.feat_drop(feat)
    keys = self.fc_u(feat)
    query = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
    scores = self.fc_e(th.sigmoid(keys + query))
    segment_sizes = g.batch_num_nodes()
    weights = F.segment.segment_softmax(segment_sizes, scores)
    out = F.segment.segment_reduce(segment_sizes, feat * weights, 'sum')
    for post in (self.fc_out, self.activation):
        if post is not None:
            out = post(out)
    return out
def collate(samples):
    '''Collate (graph, second graph, label) triples into batched inputs.

    Returns the batched first graphs (each node tagged with
    ``ndata['graph_id']``, the index of its source sample), the batched second
    graphs, and the labels as a tensor.
    '''
    graphs, diff_graphs, labels = [list(column) for column in zip(*samples)]
    batched_graph = dgl.batch(graphs)
    batched_labels = th.tensor(labels)
    batched_diff_graph = dgl.batch(diff_graphs)
    # per-node index of the source graph inside the batch
    batched_graph.ndata['graph_id'] = dgl.broadcast_nodes(
        batched_graph, th.arange(len(graphs)))
    return batched_graph, batched_diff_graph, batched_labels
def forward(self, g, feat, last_nodes):
    """Count-weighted attention readout keyed on each graph's last node.

    Per-node weights are ``fc_e(sigmoid(fc_u(feat) + fc_v(feat[last_nodes])))``
    multiplied by ``g.ndata['cnt']``; the weighted node features are summed per
    graph segment. The result then goes through ``fc_out`` and ``activation``
    when those are set.
    """
    if self.batch_norm is not None:
        feat = self.batch_norm(feat)
    if self.feat_drop is not None:
        feat = self.feat_drop(feat)
    keys = self.fc_u(feat)
    query = dgl.broadcast_nodes(g, self.fc_v(feat[last_nodes]))
    scores = self.fc_e(th.sigmoid(keys + query))  # (num_nodes, 1)
    weights = scores * g.ndata['cnt'].view_as(scores)
    out = F.segment.segment_reduce(g.batch_num_nodes(), feat * weights, 'sum')
    for post in (self.fc_out, self.activation):
        if post is not None:
            out = post(out)
    return out
def forward(self, graph: dgl.DGLGraph, feat, lambda_max=None):
    """Apply a K-order Chebyshev graph convolution.

    Parameters
    ----------
    graph : dgl.DGLGraph
        (Batched) input graph.
    feat : tensor of shape (N, *)
        Node features; N is the number of nodes.
    lambda_max : list, tuple, number, tensor or None
        Largest eigenvalue of the normalized Laplacian of each graph in the
        batch (length B), or a single value shared by every graph. When None
        it is computed with ``laplacian_lambda_max``; if ARPACK does not
        converge, 2 is used for every graph.

    Returns
    -------
    tensor
        Sum of ``self.fc[i]`` applied to each Chebyshev term T_i(L~)X, plus
        ``self.bias`` when it is not None.
    """
    # (N, 1, ..., 1): per-node scalars that broadcast against ``feat``
    shp = (len(graph.nodes()), ) + tuple(1 for _ in range(feat.dim() - 1))
    with graph.local_scope():
        norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5).reshape(shp).to(feat.device)
        if lambda_max is None:
            try:
                lambda_max = laplacian_lambda_max(graph)
            except ArpackNoConvergence:
                lambda_max = [2.] * graph.batch_size
        if not torch.is_tensor(lambda_max):
            lambda_max = torch.tensor(lambda_max)
        # Align devices for every input, including a user-supplied tensor.
        lambda_max = lambda_max.to(feat.device)
        if lambda_max.dim() == 0:
            # a single value applies to every graph in the batch: () -> (B,)
            lambda_max = lambda_max.expand(graph.batch_size)
        # broadcast from (B,) or (B, 1) to (N, 1, ..., 1)
        lambda_max = torch.reshape(broadcast_nodes(graph, lambda_max), shp).float()
        # T0(X)
        Tx_0 = feat
        rst = self.fc[0](Tx_0)
        # T1(X)
        if self._k > 1:
            graph.ndata['h'] = Tx_0 * norm
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            h = graph.ndata.pop('h') * norm
            # Λ = 2 * (I - D ^ -1/2 A D ^ -1/2) / lambda_max - I
            #   = - 2(D ^ -1/2 A D ^ -1/2) / lambda_max + (2 / lambda_max - 1) I
            Tx_1 = -2. * h / lambda_max + Tx_0 * (2. / lambda_max - 1)
            rst = rst + self.fc[1](Tx_1)
        # Ti(x), i = 2...k
        for i in range(2, self._k):
            graph.ndata['h'] = Tx_1 * norm
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            h = graph.ndata.pop('h') * norm
            # Tx_k = 2 * Λ * Tx_(k-1) - Tx_(k-2)
            #      = - 4(D ^ -1/2 A D ^ -1/2) / lambda_max Tx_(k-1) +
            #        (4 / lambda_max - 2) Tx_(k-1) -
            #        Tx_(k-2)
            Tx_2 = -4. * h / lambda_max + Tx_1 * (4. / lambda_max - 2) - Tx_0
            rst = rst + self.fc[i](Tx_2)
            Tx_1, Tx_0 = Tx_2, Tx_1
        # add bias
        if self.bias is not None:
            rst = rst + self.bias
        return rst
def test_broadcast(idtype, g):
    g = g.astype(idtype).to(F.ctx())
    gfeat = F.randn((g.batch_size, 3))

    def tiled_row(i, count):
        # row i of gfeat repeated ``count`` times
        return F.repeat(F.reshape(gfeat[i], (1, 3)), count, dim=0)

    # Test.0: broadcast_nodes -- each node receives its graph's feature row
    g.ndata['h'] = dgl.broadcast_nodes(g, gfeat)
    for i, sg in enumerate(dgl.unbatch(g)):
        assert F.allclose(sg.ndata['h'], tiled_row(i, sg.number_of_nodes()))

    # Test.1: broadcast_edges -- each edge receives its graph's feature row
    g.edata['h'] = dgl.broadcast_edges(g, gfeat)
    for i, sg in enumerate(dgl.unbatch(g)):
        assert F.allclose(sg.edata['h'], tiled_row(i, sg.number_of_edges()))
def forward(self, g, feat_i, feat_u, last_nodes):
    """Attention readout whose query combines the user and the last node.

    Keys are ``fc_key(feat_i)`` and values are ``feat_i`` (after optional
    batch-norm and dropout). The query is ``fc_user(feat_u) + fc_last(feat_i[last_nodes])``,
    broadcast to the nodes of each graph. Adding ``log(cnt)`` to the scores
    makes the per-graph softmax weights proportional to ``cnt * exp(score)``.
    """
    if self.batch_norm is not None:
        feat_i = self.batch_norm['item'](feat_i)
        feat_u = self.batch_norm['user'](feat_u)
    if self.feat_drop is not None:
        feat_i = self.feat_drop(feat_i)
        feat_u = self.feat_drop(feat_u)
    values = feat_i
    keys = self.fc_key(feat_i)
    user_part = self.fc_user(feat_u)
    last_part = self.fc_last(feat_i[last_nodes])
    query = dgl.broadcast_nodes(g, user_part + last_part)
    scores = self.fc_e(th.sigmoid(query + keys))  # (num_nodes, 1)
    scores = scores + g.ndata['cnt'].log().view_as(scores)
    segment_sizes = g.batch_num_nodes()
    weights = F.segment.segment_softmax(segment_sizes, scores)
    out = F.segment.segment_reduce(segment_sizes, weights * values, 'sum')
    if self.activation is not None:
        out = self.activation(out)
    return out
def forward(self, graphs, nodes_feat, edges_feat, nodes_num_norm_sqrt, edges_num_norm_sqrt):
    """Classify graphs using node states gated by their graph's mean state.

    ``edges_feat`` and ``edges_num_norm_sqrt`` are accepted but not used here.
    NOTE: ``graphs.ndata['h']`` is overwritten as a side effect.
    """
    h = self.in_feat_dropout(self.embedding_h(nodes_feat))
    for conv in self.layers:
        h = conv(graphs, h, nodes_num_norm_sqrt)
    graphs.ndata['h'] = h
    # copy of the graph-level mean next to every node
    mean_per_node = dgl.broadcast_nodes(graphs, dgl.mean_nodes(graphs, 'h'))
    hidden = torch.relu(self.attention1(torch.cat([h, mean_per_node], dim=-1)))
    # 2 * sigmoid(x) - 1 maps the gate into (-1, 1)
    gate = 2 * torch.sigmoid(self.attention2(hidden)) - 1
    h = h + h * gate
    graphs.ndata['h'] = h
    return self.readout_mlp(dgl.mean_nodes(graphs, 'h'))
def forward(self, graph, global_attr=None, out_node_key='h_v'):
    """Update node attributes from node, received-edge and global inputs.

    Parameters
    ----------
    graph : DGLGraph
        Graph to update in place; node inputs are read from
        ``ndata[self.node_key]`` and edge inputs from ``edata[self.edge_key]``.
    global_attr : tensor, optional
        Per-graph global attribute. When given, it is broadcast to
        ``graph.ndata['expanded_global_attr']``. It is fed to ``self.net`` only
        when ``self._use_globals`` is also set.
    out_node_key : str
        Node data key receiving the output of ``self.net``. When
        ``self.recurrent`` is set, its previous value is passed to ``self.net``
        as the recurrent state.

    Returns
    -------
    DGLGraph
        The same ``graph`` object, updated in place.
    """
    def recv_func(nodes):
        # Gather the enabled inputs of every node and concatenate them feature-wise.
        nodes_to_collect = []
        num_nodes = nodes.data[self.node_key].shape[0]
        if self._use_nodes:
            nodes_to_collect.append(nodes.data[self.node_key])
        # NOTE: aggregation of *sent* edges is not implemented.
        if self._use_received_edges:
            agg_edge_attr = getattr(torch, self._received_edges_reducer)(
                nodes.mailbox["m"], dim=1)
            # Reducers such as torch.max / torch.min return a (values, indices) pair.
            if isinstance(agg_edge_attr, tuple):
                agg_edge_attr = agg_edge_attr[0]
            nodes_to_collect.append(
                agg_edge_attr.expand(num_nodes, agg_edge_attr.shape[1]))
        if self._use_globals and global_attr is not None:
            nodes_to_collect.append(nodes.data['expanded_global_attr'])
        collected_nodes = torch.cat(nodes_to_collect, dim=-1)
        if self.recurrent:
            return {
                out_node_key: self.net(collected_nodes, nodes.data[out_node_key])
            }
        return {out_node_key: self.net(collected_nodes)}

    # Broadcasting a missing global attribute would fail; the default is None.
    if global_attr is not None:
        graph.ndata['expanded_global_attr'] = dgl.broadcast_nodes(
            graph, global_attr)
    if self._use_received_edges:
        # Route edge attributes into each destination node's mailbox so that
        # recv_func can reduce them.
        graph.update_all(fn.copy_e(self.edge_key, "m"), recv_func)
    else:
        graph.apply_nodes(recv_func)
    return graph
def forward(self, graph, feat):
    """Set2Set readout: ``self.n_iters`` rounds of LSTM-driven attention pooling.

    Returns a tensor of shape ``(batch_size, self.output_dim)``: each graph's
    final LSTM query concatenated with the attention readout of its nodes.
    """
    with graph.local_scope():
        batch_size = graph.batch_size
        state_shape = (self.n_layers, batch_size, self.input_dim)
        lstm_state = (feat.new_zeros(state_shape), feat.new_zeros(state_shape))
        q_star = feat.new_zeros(batch_size, self.output_dim)
        for _ in range(self.n_iters):
            query, lstm_state = self.lstm(q_star.unsqueeze(0), lstm_state)
            query = query.view(batch_size, self.input_dim)
            # dot product of every node feature with its own graph's query
            graph.ndata['e'] = (feat * dgl.broadcast_nodes(graph, query)).sum(
                dim=-1, keepdim=True)
            alpha = dgl.softmax_nodes(graph, 'e')
            graph.ndata['r'] = feat * alpha
            readout = dgl.sum_nodes(graph, 'r')
            q_star = torch.cat([query, readout], dim=-1)
        return q_star
def forward(self, g, node_feats, g_feats, get_node_weight=False):
    """Perform one-step readout.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs.
    node_feats : float32 tensor of shape (V, node_feat_size)
        Input node features. V for the number of nodes.
    g_feats : float32 tensor of shape (G, graph_feat_size)
        Input graph features. G for the number of graphs.
    get_node_weight : bool
        Whether to also return the per-node readout weights.

    Returns
    -------
    float32 tensor of shape (G, graph_feat_size)
        Updated graph features.
    float32 tensor of shape (V, 1)
        The weights of nodes in readout; only returned when ``get_node_weight``
        is True.
    """
    with g.local_scope():
        graph_feat_per_node = dgl.broadcast_nodes(g, F.relu(g_feats))
        g.ndata['z'] = self.compute_logits(
            torch.cat([graph_feat_per_node, node_feats], dim=1))
        g.ndata['a'] = dgl.softmax_nodes(g, 'z')
        g.ndata['hv'] = self.project_nodes(node_feats)
        g_repr = dgl.sum_nodes(g, 'hv', 'a')
        # A single (unbatched) graph yields an unbatched readout; add the batch axis.
        if not isinstance(g, BatchedDGLGraph):
            g_repr = g_repr.unsqueeze(0)
        updated = self.gru(F.elu(g_repr), g_feats)
        if not get_node_weight:
            return updated
        return updated, g.ndata['a']
def forward(self, graph: dgl.DGLGraph, feat: torch.Tensor) -> torch.Tensor:
    """Compute set2set pooling over the nodes of type ``self.ntype``.

    Args:
        graph: the input graph
        feat: The input feature with shape :math:`(N, D)` where :math:`N` is the
            number of nodes in the graph, and :math:`D` means the size of features.

    Returns:
        A tensor of shape :math:`(B, D_{out})` where :math:`B` is the batch size
        and :math:`D_{out}` is ``self.output_dim``: the final LSTM query
        concatenated with the attention readout.
    """
    ntype = self.ntype
    with graph.local_scope():
        batch_size = graph.batch_size
        state_shape = (self.n_layers, batch_size, self.input_dim)
        lstm_state = (feat.new_zeros(state_shape), feat.new_zeros(state_shape))
        q_star = feat.new_zeros(batch_size, self.output_dim)
        for _ in range(self.n_iters):
            query, lstm_state = self.lstm(q_star.unsqueeze(0), lstm_state)
            query = query.view(batch_size, self.input_dim)
            node_query = dgl.broadcast_nodes(graph, query, ntype=ntype)
            graph.nodes[ntype].data["e"] = (feat * node_query).sum(dim=-1, keepdim=True)
            alpha = dgl.softmax_nodes(graph, "e", ntype=ntype)
            graph.nodes[ntype].data["r"] = feat * alpha
            readout = dgl.sum_nodes(graph, "r", ntype=ntype)
            q_star = torch.cat([query, readout], dim=-1)
        return q_star
def forward(self, g, feature, e):
    """Apply a ChebNet layer that uses a fixed largest eigenvalue of 2.

    Parameters
    ----------
    g : DGLGraph
        (Batched) input graph.
    feature : tensor
        Node features; also used for the residual connection.
    e : tensor
        Edge features; returned unchanged.

    Returns
    -------
    tuple
        ``(h, e)`` where ``h`` is the concatenation of the Chebyshev terms
        projected by ``self.linear``, followed by the optional batch norm,
        activation and residual, and then dropout.
    """
    h_in = feature  # to be used for residual connection

    def unnLaplacian(feature, D_sqrt, graph):
        """Operation D^-1/2 A D^-1/2 applied to ``feature`` by message passing."""
        graph.ndata['h'] = feature * D_sqrt
        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        return graph.ndata.pop('h') * D_sqrt

    with g.local_scope():
        D_sqrt = torch.pow(g.in_degrees().float().clamp(
            min=1), -0.5).unsqueeze(-1).to(feature.device)
        # lambda_max is fixed to 2 for every graph (no eigenvalue computation is
        # performed): one value per graph, shape (B, 1), broadcast to (N, 1).
        lambda_max = torch.full((g.batch_size, 1), 2.0, device=feature.device)
        lambda_max = dgl.broadcast_nodes(g, lambda_max)
        re_norm = (2. / lambda_max).to(feature.device)
        # X_0(f)
        Xt = X_0 = feature
        # X_1(f)
        if self._k > 1:
            h = unnLaplacian(X_0, D_sqrt, g)
            X_1 = - re_norm * h + X_0 * (re_norm - 1)
            Xt = torch.cat((Xt, X_1), 1)
        # Xi(x), i = 2...k
        for _ in range(2, self._k):
            h = unnLaplacian(X_1, D_sqrt, g)
            X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
            Xt = torch.cat((Xt, X_i), 1)
            X_1, X_0 = X_i, X_1

    h = self.linear(Xt)
    if self.batch_norm:
        h = self.batchnorm_h(h)  # batch normalization
    if self.activation:
        h = self.activation(h)
    if self.residual:
        h = h_in + h  # residual connection
    h = self.dropout(h)
    return h, e
def forward(self, g, node_feats, edge_feats):
    """Update node representations.

    Parameters
    ----------
    g : DGLGraph
        DGLGraph for a batch of graphs
    node_feats : LongTensor of shape (N, 1)
        Input categorical node features. N for the number of nodes.
    edge_feats : FloatTensor of shape (E, in_edge_feats)
        Input edge features. E for the number of edges.

    Returns
    -------
    FloatTensor of shape (N, hidden_feats)
        Output node representations. When ``self.jk`` is set, this is the sum
        of the input encoding and every layer's output; otherwise it is the
        last layer's output.
    """
    use_gcn = self.gnn_type == 'gcn'
    if use_gcn:
        # Symmetric normalisation; degrees are incremented by one (self-loop).
        degs = (g.in_degrees().float() + 1).to(node_feats.device)
        g.ndata['norm'] = torch.pow(degs, -0.5).unsqueeze(-1)  # (N, 1)
        g.apply_edges(fn.u_mul_v('norm', 'norm', 'norm'))
        norm = g.edata.pop('norm')
    if self.virtual_node:
        vn_index = torch.zeros(g.batch_size).to(node_feats.dtype).to(node_feats.device)
        virtual_node_feats = self.virtual_node_emb(vn_index)
    h_list = [self.node_encoder(node_feats)]
    last = self.n_layers - 1
    for l, layer in enumerate(self.layers):
        if self.virtual_node:
            # virtual node -> graph nodes
            h_list[l] = h_list[l] + dgl.broadcast_nodes(g, virtual_node_feats)
        if use_gcn:
            h = layer(g, h_list[l], edge_feats, degs, norm)
        else:
            h = layer(g, h_list[l], edge_feats)
        if self.batchnorms is not None:
            h = self.batchnorms[l](h)
        if self.activation is not None and l != last:
            h = self.activation(h)
        h = self.dropout(h)
        h_list.append(h)
        if self.virtual_node and l < last:
            # graph nodes -> virtual node
            pooled = self.virtual_readout(g, h_list[l]) + virtual_node_feats
            projected = self.dropout(self.mlp_virtual_project[l](pooled))
            virtual_node_feats = virtual_node_feats + projected if self.residual else projected
    if self.jk:
        return torch.stack(h_list, dim=0).sum(0)
    return h_list[-1]
def forward(self, graph, feat, lambda_max=None):
    r"""

    Description
    -----------
    Compute ChebNet layer.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    feat : torch.Tensor
        The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
        is size of input feature, :math:`N` is the number of nodes.
    lambda_max : list or tensor or None, optional.
        A list(tensor) with length :math:`B`, stores the largest eigenvalue
        of the normalized laplacian of each individual graph in ``graph``,
        where :math:`B` is the batch size of the input graph. Default: None.
        If None, this method would compute the list by calling
        ``dgl.laplacian_lambda_max``; if that fails, 2 is used for every graph.

    Returns
    -------
    torch.Tensor
        The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
        is size of output feature.
    """
    def unnLaplacian(feat, D_invsqrt, graph):
        """Operation Feat * D^-1/2 A D^-1/2.

        Written as a matrix product this is D^-1/2 A D^-1/2 Feat.
        """
        graph.ndata['h'] = feat * D_invsqrt
        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
        return graph.ndata.pop('h') * D_invsqrt

    with graph.local_scope():
        # Small modification of the original degree normalisation.
        if self.is_mnist:
            # Weighted in-degree: sum of edge feature 'v' over incoming edges.
            # NOTE(review): 'v' is tied to coordinate.py; this assumes edata['v'] is
            # 1-D, otherwise unsqueeze(-1) below yields an extra axis — confirm.
            graph.update_all(fn.copy_e('v', 'm'), fn.sum('m', 'h'))
            D_invsqrt = th.pow(
                graph.ndata.pop('h').float().clamp(min=1), -0.5).unsqueeze(-1).to(feat.device)
        else:
            D_invsqrt = th.pow(graph.in_degrees().float().clamp(min=1),
                               -0.5).unsqueeze(-1).to(feat.device)

        if lambda_max is None:
            try:
                lambda_max = laplacian_lambda_max(graph)
            except Exception:
                # if the largest eigenvalue is not found
                dgl_warning(
                    "Largest eigonvalue not found, using default value 2 for lambda_max",
                    RuntimeWarning)
                # One value per graph; the previous ``th.Tensor(2)`` allocated an
                # uninitialised tensor of shape (2,) instead of the value 2.
                lambda_max = [2.] * graph.batch_size
        if isinstance(lambda_max, list):
            lambda_max = th.Tensor(lambda_max)
        lambda_max = lambda_max.to(feat.device)
        if lambda_max.dim() == 1:
            lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
        # broadcast from (B, 1) to (N, 1)
        lambda_max = broadcast_nodes(graph, lambda_max)
        re_norm = 2. / lambda_max

        # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
        Xt = X_0 = feat

        # X_1(f)
        if self._k > 1:
            h = unnLaplacian(X_0, D_invsqrt, graph)
            X_1 = -re_norm * h + X_0 * (re_norm - 1)
            # Concatenate Xt and X_1
            Xt = th.cat((Xt, X_1), 1)

        # Xi(x), i = 2...k
        for _ in range(2, self._k):
            h = unnLaplacian(X_1, D_invsqrt, graph)
            X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
            # Concatenate Xt and X_i
            Xt = th.cat((Xt, X_i), 1)
            X_1, X_0 = X_i, X_1

        # linear projection
        h = self.linear(Xt)

        # activation
        if self.activation:
            h = self.activation(h)

    return h