def forward(self, x, edge_index, edge_attr, self_loop_index=None, self_loop_type=None):
    """Combine node features with edge embeddings, aggregate per node, apply ``self.mlp``.

    A self-loop is added for every node. When ``self_loop_index`` is given, one
    attribute row per new loop is appended to ``edge_attr``: all zeros except
    column ``self_loop_index``, which is set to ``self_loop_type``.

    Args:
        x: node inputs; ``x.size(0)`` is used as the node count for the self-loops.
            If ``self.input_node_embeddings`` is set, ``x`` is flattened to
            integer ids and embedded first.
        edge_index: graph connectivity; ``edge_index[1]`` selects, for every
            edge, the node features that form the edge message.
        edge_attr: 2-D per-edge attributes (rows are concatenated with the
            self-loop rows along dim 0).
        self_loop_index: column of ``edge_attr`` that encodes the self-loop
            type, or None to leave ``edge_attr`` unchanged.
        self_loop_type: value written into that column for the self-loop rows.

    Returns:
        ``self.mlp`` applied to the output of ``self.aggr``.

    Raises:
        NotImplementedError: if neither ``self.edge_embeddings`` nor
            ``self.edge_encoder`` is set.
    """
    num_nodes = x.size(0)
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

    if self_loop_index is not None:
        # Allocate the self-loop attribute rows directly on edge_attr's
        # device/dtype instead of building them on CPU and copying.
        self_loop_attr = torch.zeros(
            num_nodes, edge_attr.size(1), device=edge_attr.device, dtype=edge_attr.dtype
        )
        self_loop_attr[:, self_loop_index] = self_loop_type
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)
    # NOTE(review): when self_loop_index is None, edge_attr is not extended, so
    # it must already have a row for each added self-loop — confirm with callers.

    if self.edge_embeddings is not None:
        # Column i of edge_attr is embedded by the i-th embedding; results are summed.
        # .to() is kept because the embeddings may not be registered submodules
        # (presumably why the original moved them on every call) — TODO confirm.
        for embedding in self.edge_embeddings:
            embedding.to(x.device)
        edge_embeddings = sum(
            embedding(edge_attr[:, i]) for i, embedding in enumerate(self.edge_embeddings)
        )
    elif self.edge_encoder is not None:
        edge_embeddings = self.edge_encoder(edge_attr)
    else:
        raise NotImplementedError

    if self.input_node_embeddings is not None:
        x = self.input_node_embeddings(x.long().view(-1))

    # Per-edge message: features of node edge_index[1] joined with the edge embedding.
    if self.feature_concat:
        h = torch.cat((x[edge_index[1]], edge_embeddings), dim=1)
    else:
        h = x[edge_index[1]] + edge_embeddings
    h = self.aggr(h, edge_index, x.size(0))
    return self.mlp(h)
def forward(self, x, edge_index):
    """Project node features for every head and run message passing.

    Existing self-loops are removed and one self-loop per node is added before
    ``self.propagate`` is called.

    Args:
        x: 2-D node feature matrix (multiplied by ``self.weight`` via ``torch.mm``).
        edge_index: graph connectivity.

    Returns:
        The result of ``self.propagate``.
    """
    edge_index, _ = remove_self_loops(edge_index)
    # Some versions of add_self_loops return (edge_index, edge_weight) rather
    # than a bare tensor; accept either form.
    looped = add_self_loops(edge_index, num_nodes=x.size(0))
    edge_index = looped[0] if isinstance(looped, tuple) else looped
    # Reshape to (-1, heads, out_channels); presumably self.weight has
    # heads * out_channels columns so the first dim stays the node count — TODO confirm.
    x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
    return self.propagate(edge_index, x=x, num_nodes=x.size(0))
def __call__(self, data):
    """Add a self-loop for every node of ``data`` and coalesce its edge index.

    Only graphs without edge attributes are accepted (``data.edge_attr`` must
    be None), because no attributes are created for the new loops.
    ``data.edge_index`` is replaced in place and ``data`` is returned.
    """
    num_nodes = data.num_nodes
    assert data.edge_attr is None, "edge_attr is not supported when adding self-loops"
    # Some versions of add_self_loops return (edge_index, edge_weight) rather
    # than a bare tensor; accept either form so coalesce gets a tensor.
    looped = add_self_loops(data.edge_index, num_nodes=num_nodes)
    edge_index = looped[0] if isinstance(looped, tuple) else looped
    edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
    data.edge_index = edge_index
    return data
def bingge_norm_adj(adj, adj_values, num_nodes):
    """Symmetrically normalise a sparse adjacency after adding self-loops.

    Steps, exactly as implemented:
      1. ``add_self_loops(adj, adj_values, 1, num_nodes)`` — presumably one
         loop per node with value 1 (TODO confirm the helper's signature).
      2. ``degree = spmm(adj, adj_values, ones)`` — presumably the weighted
         row sums of the adjacency.
      3. Each value at ``(adj[0][k], adj[1][k])`` is scaled to
         ``degree[adj[1][k]] ** -0.5 * value * degree[adj[0][k]] ** -0.5``.
      4. ``adj_values[adj[0][off_diagonal]] += 1`` (see NOTE below).

    Args:
        adj: index tensor; ``adj[0]`` and ``adj[1]`` are used as row/column ids.
        adj_values: one value per entry of ``adj[0]`` / ``adj[1]``.
        num_nodes: node count; length of the ``ones`` vector passed to ``spmm``.

    Returns:
        ``(adj, adj_values)`` including the self-loops, with rescaled values.
    """
    adj, adj_values = add_self_loops(adj, adj_values, 1, num_nodes)
    row_idx = adj[0]
    col_idx = adj[1]

    ones = torch.ones(num_nodes, 1, device=adj.device)
    degree = spmm(adj, adj_values, ones).squeeze()
    inv_sqrt_degree = degree.pow(-0.5)

    adj_values = inv_sqrt_degree[col_idx] * adj_values * inv_sqrt_degree[row_idx]

    off_diagonal = row_idx != col_idx
    # NOTE(review): row_idx[off_diagonal] holds NODE ids but indexes adj_values,
    # which is aligned with EDGE positions; with advanced indexing `+=` also adds 1
    # only once per distinct index. This looks unintended (perhaps
    # `adj_values[off_diagonal] += 1` or `adj_values[~off_diagonal] += 1` was
    # meant) — behavior kept as-is pending confirmation.
    adj_values[row_idx[off_diagonal]] += 1
    return adj, adj_values
def norm(edge_index, num_nodes, edge_weight, gcn=False, dtype=None):
    """Replace self-loops and row-normalise edge weights by source-node degree.

    Existing self-loops are dropped together with their weights. Then one
    self-loop per node is appended, with weight 1 if ``gcn`` else 0. Finally
    every weight is divided by the weighted degree of its ``edge_index[0]`` node
    (``D^{-1} A``). Nodes whose degree is zero get factor 0 instead of
    ``inf``, which would otherwise turn their zero-weight loop into NaN.

    Args:
        edge_index: ``(2, E)`` integer tensor of ``(row, col)`` pairs.
        num_nodes: number of nodes ``N``.
        edge_weight: per-edge weights (flattened to ``(E,)``), or None for ones.
        gcn: if True, appended self-loops get weight 1; otherwise 0.
        dtype: dtype of the default all-ones weights when ``edge_weight`` is None.

    Returns:
        ``(edge_index, edge_weight)`` with ``E' + N`` entries, where ``E'`` is
        the number of non-self-loop input edges; loops come last, in node order.
    """
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Drop existing self-loops, filtering the weights with the SAME mask so the
    # two stay aligned (filtering only edge_index misaligned them).
    keep = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, keep]
    edge_weight = edge_weight[keep]

    # Append one self-loop per node at the end, matching the loop weights below.
    loop = torch.arange(num_nodes, dtype=edge_index.dtype, device=edge_index.device)
    edge_index = torch.cat([edge_index, torch.stack([loop, loop], dim=0)], dim=1)
    loop_weight = torch.full(
        (num_nodes,), 1 if gcn else 0, dtype=edge_weight.dtype, device=edge_weight.device
    )
    edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    row = edge_index[0]
    deg = torch.zeros(num_nodes, dtype=edge_weight.dtype, device=edge_weight.device)
    deg.scatter_add_(0, row, edge_weight)
    deg_inv = deg.pow(-1)
    # Zero-degree nodes: inf * 0 would give NaN, so use factor 0 instead.
    deg_inv.masked_fill_(torch.isinf(deg_inv), 0)
    return edge_index, deg_inv[row] * edge_weight