def propagate(self, edge_index, x, norm):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Run one round of message passing with a fixed-size output.

    Features are gathered at ``edge_index[1]``, turned into messages by
    ``self.message`` (together with ``norm``), summed per ``edge_index[0]``
    into a tensor with ``x.size(0)`` rows, and finally passed through
    ``self.update``.
    """
    receivers, senders = edge_index[0], edge_index[1]
    messages = self.message(torch.index_select(x, 0, senders), norm)
    # Pad to x.size(0) rows so nodes without incoming messages still get a row.
    aggregated = scatter_add(messages, receivers, 0, None, dim_size=x.size(0))
    return self.update(aggregated)
def forward(self, x, edge_index, edge_type, edge_norm):
    # type: (Optional[Tensor], Tensor, Tensor, Optional[Tensor]) -> Tensor
    """Relational graph convolution with basis-decomposed relation weights.

    Args:
        x: node features, or ``None``. When ``None``, the per-edge weight row
            ``rel_weight[edge_type, src]`` is used directly, which is what a
            one-hot (identity) feature matrix would select.
        edge_index: edge endpoints; messages are gathered at
            ``edge_index[1]`` and summed into ``edge_index[0]``.
        edge_type: relation id of every edge (selects a relation weight).
        edge_norm: optional per-edge scale applied to every message.

    Returns:
        Aggregated messages plus a root (self) term and ``self.bias``.
        ``scatter_add`` is called without ``dim_size``, so the result has
        ``edge_index[0].max() + 1`` rows and only that many leading rows of
        ``x`` feed the root term. Presumably target nodes come first in a
        sampled subgraph -- TODO confirm with callers.
    """
    # One weight per relation, as a linear combination of the shared bases.
    weight = torch.matmul(self.att, self.basis.view(self.n_bases, -1))
    src = edge_index[1]
    if x is not None:
        rel_weight = weight.view(self.n_relations, self.in_h, self.out_h)
        per_edge_w = torch.index_select(rel_weight, 0, edge_type)
        x_src = torch.index_select(x, 0, src)
        msg = torch.bmm(x_src.unsqueeze(1), per_edge_w).squeeze(-2)
    else:
        # Row (edge_type * in_h + src) of the flattened weight equals
        # weight.view(n_relations, in_h, out_h)[edge_type, src].
        flat_weight = weight.view(-1, self.out_h)
        msg = torch.index_select(flat_weight, 0, edge_type * self.in_h + src)

    if edge_norm is not None:
        msg = msg * edge_norm.view(-1, 1)

    # No dim_size: output rows = edge_index[0].max() + 1.
    out = scatter_add(msg, edge_index[0], dim=0)

    if x is not None:
        root_term = torch.matmul(x[0:out.size(0)], self.root)
    else:
        root_term = self.root
    return out + root_term + self.bias
def softmax(src, index, num_nodes):
    # type: (Tensor, Tensor, int) -> Tensor
    """Numerically stable softmax of ``src`` over groups given by ``index``.

    Entries sharing the same ``index`` value form one group; the output has
    the shape of ``src`` and sums to 1 within every non-empty group.

    Args:
        src: scores to normalize (grouped along dim 0).
        index: group id for every row of ``src``.
        num_nodes: number of groups (resolved through ``maybe_num_nodes``).
    """
    num_nodes = maybe_num_nodes(index, num_nodes)
    # scatter_max returns (max_values, argmax); only the values are needed.
    # NOTE(review): assumes torch_scatter's tuple-returning scatter_max.
    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    # Shift by the group max before exp so large scores cannot overflow.
    out = (src - group_max[index]).exp()
    # Normalize by the per-group sum of the *exponentiated* values; the old
    # code summed the pre-exp values, which is not a softmax.
    group_sum = scatter_add(out, index, dim=0, dim_size=num_nodes)
    return out / (group_sum[index] + 1e-16)
def propagate(self, edge_index, x, norm):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Run one round of message passing without padding the output.

    Features are gathered at ``edge_index[1]``, turned into messages by
    ``self.message`` (together with ``norm``), summed per ``edge_index[0]``,
    and passed through ``self.update``.

    Returns:
        The updated aggregation, with ``edge_index[0].max() + 1`` rows --
        possibly fewer than ``x.size(0)``.
    """
    x_j = torch.index_select(x, 0, edge_index[1])
    out = self.message(x_j, norm)
    # dim_size is deliberately omitted: scatter_add then sizes the output by
    # the scatter index, i.e. edge_index[0].max() + 1 rows (not edge_index[1]).
    out = scatter_add(out, edge_index[0], 0, None)
    out = self.update(out)
    return out
def spmm(index, value, m, matrix):
    # type: (Tensor, Tensor, int, Tensor) -> Tensor
    """Multiply a COO sparse matrix by a dense matrix.

    The sparse matrix has entries ``value[e]`` at ``(index[0][e], index[1][e])``
    and ``m`` rows. A 1-D ``matrix`` is treated as a single column, so the
    result is always at least 2-D with ``m`` rows.
    """
    dense = matrix.unsqueeze(-1) if matrix.dim() <= 1 else matrix
    # Each nonzero scales the dense row it points at ...
    contributions = dense[index[1]] * value.unsqueeze(-1)
    # ... and those scaled rows are summed into the output row it belongs to.
    return scatter_add(contributions, index[0], dim=0, dim_size=m)
def propagate(self, edge_index, x, num_nodes):
    # type: (Tensor, Tensor, int) -> Tensor
    """Run one round of message passing with access to both edge endpoints.

    ``self.message`` receives the receiver ids (``edge_index[1]``), the
    receiver features ``x_i``, the sender features ``x_j`` (from
    ``edge_index[0]``) and ``num_nodes``. Messages are summed per receiver
    into ``x.size(0)`` rows, passed through ``self.update``, and the size-1
    axis 1 is squeezed away.
    """
    receivers, senders = edge_index[1], edge_index[0]
    x_i = torch.index_select(x, 0, receivers)
    x_j = torch.index_select(x, 0, senders)
    messages = self.message(receivers, x_i, x_j, num_nodes)
    aggregated = scatter_add(messages, receivers, dim_size=x.size(0), dim=0)
    # Presumably num_nodes x heads x out_channels here.
    # TODO: support multi-head outputs; only a single head (axis 1 of size 1)
    # collapses cleanly under squeeze(1).
    return self.update(aggregated).squeeze(1)
def norm(self, edge_index, num_nodes):
    # type: (Tensor, int) -> Tuple[Tensor, Tensor]
    """Symmetric degree normalization after adding remaining self-loops.

    Every edge starts with weight 1; ``add_remaining_self_loops`` is called
    with fill value 1. Returns the (possibly extended) ``edge_index`` and
    per-edge weights ``deg(row)^-1/2 * w * deg(col)^-1/2``, where the degree
    is counted over ``edge_index[0]``.

    No inf masking is done here -- presumably the added self-loops guarantee
    every node has degree >= 1 (TODO confirm add_remaining_self_loops).
    """
    unit_weights = torch.ones((edge_index.size(1), ), device=edge_index.device)
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, unit_weights, 1, num_nodes)
    row, col = edge_index[0], edge_index[1]
    inv_sqrt_deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes).pow(-0.5)
    return edge_index, inv_sqrt_deg[row] * edge_weight * inv_sqrt_deg[col]
def norm(self, edge_index, num_nodes):
    # type: (Tensor, int) -> Tuple[Tensor, Tensor]
    """Symmetric degree normalization without adding self-loops.

    Every edge gets weight 1. Returns ``edge_index`` unchanged and per-edge
    weights ``deg(row)^-1/2 * deg(col)^-1/2``, with the degree counted over
    ``edge_index[0]``.
    """
    edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    # Nodes that never appear in edge_index[0] have degree 0 and 0 ** -0.5 is
    # inf; zero them so they contribute nothing instead of inf/nan. (The old
    # `if mask is not None` guard was always true: a tensor is never None.)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def __merge_edges__(self, x, edge_index, edge_score):
    """Greedily contract the highest-scoring edges into clusters (NumPy).

    Edges are visited in descending ``edge_score`` order. An edge is
    contracted when neither endpoint has been assigned yet: both endpoints
    (one node for a self-loop) form a new cluster. Nodes left unassigned
    become singleton clusters.

    Returns:
        ``(new_x, new_edge_index)``. ``new_x`` has one row per cluster, built
        by ``scatter_add`` from ``x`` (per the original intent, the sum of the
        member features). ``new_edge_index`` is the nonzero pattern of
        ``coalesce`` over the cluster-relabelled edges (presumably a sparse
        adjacency -- TODO confirm coalesce's return type).
    """
    num_nodes = x.shape[0]
    nodes_remaining = set(range(num_nodes))
    cluster = np.zeros(num_nodes, dtype=edge_index.dtype)
    # Convert endpoints to Python ints once instead of 2-D indexing per edge.
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    # argsort is ascending; reverse it to visit the best-scoring edges first.
    i = 0
    for edge_idx in np.argsort(edge_score)[::-1].tolist():
        source, target = sources[edge_idx], targets[edge_idx]
        if source not in nodes_remaining or target not in nodes_remaining:
            continue  # an endpoint already belongs to an earlier cluster
        cluster[source] = i
        nodes_remaining.remove(source)
        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)
        i += 1

    # Every node not merged above is kept as its own cluster.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1

    # Clusters are numbered 0..i-1, so i is the new node count. Unlike
    # np.max(cluster) + 1, this does not raise for a graph with no nodes.
    new_num_nodes = i
    new_x = np.zeros((new_num_nodes, x.shape[1]), dtype=x.dtype)
    new_x = scatter_add(new_x, cluster, x)
    N = new_x.shape[0]
    new_edge_index = coalesce(cluster[edge_index], None, N, N)
    new_edge_index = np.array(new_edge_index.nonzero(), dtype=edge_index.dtype)
    return new_x, new_edge_index
def forward(self, x, edge_index, edge_weight=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    """Chebyshev spectral graph convolution.

    Computes ``sum_{k<K} T_k(L_hat) x W_k`` with ``K = self.weight.size(0)``,
    using the recurrence ``T_0 x = x``, ``T_1 x = L_hat x`` and
    ``T_k x = 2 L_hat T_{k-1} x - T_{k-2} x``. Here
    ``L_hat = -D^-1/2 A D^-1/2`` is the normalized Laplacian rescaled with
    lambda_max = 2, built after removing self-loops. ``self.bias`` is added
    when present.

    Args:
        x: node feature matrix (used with ``torch.mm``, so 2-D).
        edge_index: edge endpoints; degree is counted over ``edge_index[0]``.
        edge_weight: optional per-edge weights; defaults to ones.
    """
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    row, col = edge_index[0], edge_index[1]
    num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)

    if edge_weight is None:
        edge_weight = torch.ones((num_edges,), dtype=x.dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)

    # Compute the normalized and rescaled graph Laplacian.
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg = deg.pow(-0.5)
    # A node absent from edge_index[0] has degree 0 -> inf, which would make
    # lap inf (or nan via inf * 0 on zero-weight edges); zero those entries.
    deg.masked_fill_(deg == float('inf'), 0)
    lap = -deg[row] * edge_weight * deg[col]

    # Apply the filter with the Chebyshev recurrence.
    Tx_0 = x
    out = torch.mm(Tx_0, self.weight[0])
    if K > 1:
        # Only propagate when a first-order term is actually used.
        Tx_1 = spmm(edge_index, lap, num_nodes, x)
        out = out + torch.mm(Tx_1, self.weight[1])
        for k in range(2, K):
            Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
            out = out + torch.mm(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

    if self.bias is not None:
        out = out + self.bias
    return out
def predict(valid_id, v):
    """Return the index of the label whose aggregated score is largest.

    For every label ``0 .. n_label - 1``, ``scatter_add(v[label], valid_id)``
    produces its score, and ``np.argmax`` picks the winner. ``n_label`` is a
    module-level name defined outside this function (TODO confirm where).
    """
    per_label = [scatter_add(v[label], valid_id) for label in range(n_label)]
    return np.argmax(per_label)
def edge_softmax(src, index, num_nodes):
    """Numerically stable softmax of ``src`` over groups sharing an ``index``.

    Args:
        src: per-entry scores; assumed 1-D, since ``norm[index]`` below is 1-D.
        index: integer group id of every entry, in ``[0, num_nodes)``.
        num_nodes: number of groups.

    Returns:
        Array shaped like ``src`` whose entries sum to 1 within each group.
        Groups whose scores are all ``-inf`` yield zeros.
    """
    src = np.asarray(src)
    index = np.asarray(index)
    # Shift each score by its group's maximum before exponentiating. Softmax
    # is shift-invariant, and this stops np.exp overflowing to inf (the
    # unshifted version returned inf / inf = nan for large scores).
    group_max = np.full(num_nodes, -np.inf)
    np.maximum.at(group_max, index, src)
    # Leave groups with a non-finite max unshifted, so e.g. an all -inf group
    # still gives 0 rather than nan from (-inf) - (-inf).
    group_max[~np.isfinite(group_max)] = 0.0
    src = np.exp(src - group_max[index])
    # Per-group sum of exponentials; ufunc.at is unbuffered, so repeated
    # indices accumulate correctly.
    norm = np.zeros(num_nodes)
    np.add.at(norm, index, src)
    out = src / (norm[index] + 1e-16)
    return out