def forward(self, x, edge_index):
    """Aggregate neighbour features and post-process them.

    Self-loops are stripped from ``edge_index`` (and intentionally not
    re-added) before message passing. The propagated result goes through
    ``self.fea_mlp`` and then, when configured, through ``self.BN``.
    """
    edge_index = remove_self_loops(edge_index)[0]
    aggregated = self.propagate(edge_index, x=x)
    out = self.fea_mlp(aggregated)
    return out if self.BN is None else self.BN(out)
def get_icosahedron_weights(nodes, depth):
    """Build icosahedral-graph edge lists for every level of a UNet.

    Args:
        nodes (int): number of nodes of the starting (highest-order)
            icosahedron.
        depth (int): number of levels to build (the depth of the UNet).

    Returns:
        tuple: ``(edge_list, weight_list)``. ``edge_list`` holds one
        ``edge_index`` tensor per level, ordered by increasing icosahedron
        order (``edge_list[0]`` is the lowest order). ``weight_list`` holds
        the matching edge weights in the same order; currently every entry
        is ``None``.
    """
    edge_list = []
    weight_list = []
    order = icosahedron_order_calculator(nodes)
    for _ in range(depth):
        nodes = icosahedron_nodes_calculator(order)
        order_initial = icosahedron_order_calculator(nodes)
        coords = torch.from_numpy(get_ico_coords(int(order_initial)))
        # Order 0 uses 5 nearest neighbours, higher orders use 6.
        edge_index = knn_graph(coords, 6 if order else 5)
        if order:
            # NOTE(review): presumably the 12 original icosahedron vertices
            # only have 5 true neighbours, so 6-NN gives each one spurious
            # (longest) edge. The 12 longest edges are turned into
            # self-loops here so that remove_self_loops drops them.
            dist = torch.norm(coords[edge_index[0]] - coords[edge_index[1]], p=2, dim=1)
            _, extra_idx = torch.topk(dist, 12)
            edge_index[0, extra_idx] = edge_index[1, extra_idx]
        edge_index, _ = remove_self_loops(edge_index)
        edge_list.append(edge_index)
        weight_list.append(None)
        order -= 1
    # Reverse both lists so weights stay aligned with their edge lists.
    return edge_list[::-1], weight_list[::-1]
def forward(self, x, edge_index):
    """Multi-head attention convolution with stochastic attention dropout.

    Args:
        x: node features; a 1-D tensor is treated as a single feature column.
        edge_index: graph connectivity. Existing self-loops are replaced by
            exactly one self-loop per node.

    Returns:
        Tensor of shape ``(-1, heads * out_channels)`` when ``self.concat is
        True``; otherwise the heads are averaged to ``(-1, out_channels)``.
        ``self.bias`` is added when present.
    """
    if x.dim() == 1:
        x = x.unsqueeze(-1)
    h = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
    num_nodes = h.size(0)

    # Replace any existing self-loops by exactly one per node.
    edge_index, _ = remove_self_loops(edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=num_nodes)
    row, col = edge_index

    # Attention logits from the concatenated (row, col) features per head.
    pair = torch.cat([h[row], h[col]], dim=-1)
    scores = F.leaky_relu((pair * self.att_weight).sum(dim=-1), self.negative_slope)
    scores = softmax(scores, row, num_nodes=num_nodes)

    # The rate is zeroed outside training, so training=True is safe here.
    rate = self.dropout if self.training else 0
    scores = F.dropout(scores, p=rate, training=True)

    # Weighted sum of neighbour features per target node.
    messages = scores.view(-1, self.heads, 1) * h[col]
    out = scatter_add(messages, row, dim=0, dim_size=num_nodes)

    if self.concat is True:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.sum(dim=1) / self.heads

    if self.bias is not None:
        out += self.bias
    return out
def load_node_pair(adj, file_path, n_hop=10, device=None):
    """Load or build positive/negative node-pair edge indices.

    If ``file_path`` contains ``pos_node_edge.pt``, the cached pair is loaded
    from ``pos_node_edge.pt`` / ``neg_node_edge.pt``. Otherwise:

    * ``neighbor`` marks pairs (i, j) whose rows of ``adj > 0`` share a
      column (``adj @ adj.T > 0``);
    * positives are the ``neighbor`` entries restricted to columns whose
      count lies in ``[max(10th-percentile count, 2), 90th-percentile count]``,
      with self-loops removed;
    * negatives are the pairs NOT marked after chaining ``neighbor`` with
      itself ``n_hop - 1`` times.

    The results are moved to CPU and cached in ``file_path`` when it is given.

    Args:
        adj: dense adjacency matrix; entries ``> 0`` are edges.
        file_path: cache directory, or a falsy value to disable caching.
        n_hop: number of chained ``neighbor`` steps used for negatives.
        device: device for the dense computation. Defaults to ``adj``'s
            current device (this used to be hard-coded to ``'cuda:8'``).

    Returns:
        tuple: ``(pos_neighbor, neg_neighbor)`` edge-index tensors.
    """
    import os
    from os.path import join as pjoin

    pos_path = pjoin(file_path, 'pos_node_edge.pt') if file_path else None
    neg_path = pjoin(file_path, 'neg_node_edge.pt') if file_path else None
    if pos_path and os.path.exists(pos_path):
        return torch.load(pos_path), torch.load(neg_path)

    if device is not None:
        adj = adj.to(device)
    adj = adj > 0
    neighbor = torch.mm(adj.float(), adj.float().t()) > 0
    size = neighbor.size(1)

    # Keep only columns with a "typical" count, dropping rare and hub nodes.
    pos_neighbor = torch.zeros_like(neighbor)
    sort_weight, indices = torch.sum(neighbor.float(), dim=0).sort()
    min_pos = sort_weight[int(size * 0.1)]
    max_pos = sort_weight[int(size * 0.9)]
    min_count = (sort_weight < max(min_pos, 2)).sum()
    max_count = (sort_weight <= max_pos).sum()
    pos_select = indices[min_count:max_count]
    pos_neighbor[:, pos_select] = neighbor[:, pos_select]

    # Negatives: pairs not reachable through the chained neighbour relation.
    neg_neighbor = neighbor
    for _ in range(1, n_hop):
        neg_neighbor = torch.mm(neg_neighbor.float(), neighbor.float()) > 0
    neg_neighbor = ~neg_neighbor

    pos_neighbor, _ = dense_to_sparse(pos_neighbor)
    pos_neighbor = remove_self_loops(pos_neighbor)[0].cpu()
    neg_neighbor, _ = dense_to_sparse(neg_neighbor)
    neg_neighbor = neg_neighbor.cpu()
    if file_path:
        torch.save(pos_neighbor, pos_path)
        torch.save(neg_neighbor, neg_path)
    return pos_neighbor, neg_neighbor
def process(self):
    """Read the raw files, build a single ``Data`` object and save it.

    ``raw_paths[0]`` is a JSON graph with ``nodes`` (each with ``id``, ``val``
    and ``test`` flags) and ``links`` (``source``/``target``).
    ``raw_paths[1]`` is a NumPy feature matrix. ``raw_paths[2]`` is a JSON
    dict mapping stringified node indices ``"0".."n-1"`` to labels.
    """
    with open(self.raw_paths[0], 'r') as f:
        graph_data = json.load(f)

    # Split code per node: 0 = train, 1 = val, 2 = test.
    split = torch.zeros(len(graph_data['nodes']), dtype=torch.uint8)
    for node in graph_data['nodes']:
        if node['val']:
            split[node['id']] = 1
        elif node['test']:
            split[node['id']] = 2
        else:
            split[node['id']] = 0

    links = graph_data['links']
    edge_index = torch.stack([
        torch.tensor([link['source'] for link in links]),
        torch.tensor([link['target'] for link in links]),
    ], dim=0)
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = coalesce(edge_index, num_nodes=split.size(0))

    x = torch.from_numpy(np.load(self.raw_paths[1])).float()

    with open(self.raw_paths[2], 'r') as f:
        y_data = json.load(f)
    y = torch.tensor([y_data[str(k)] for k in range(len(y_data))], dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = split == 0
    data.val_mask = split == 1
    data.test_mask = split == 2
    if self.pre_transform is not None:
        data = self.pre_transform(data)
    collated, slices = self.collate([data])
    torch.save((collated, slices), self.processed_paths[0])
def forward(self, x, edge_index, size=None, return_attention_weights=False):
    """Linearly transform ``x`` and propagate along ``edge_index``.

    For a single feature tensor without an explicit ``size``, self-loops are
    normalised to exactly one per node first. A non-tensor ``x`` is treated
    as a (source, target) pair whose ``None`` entries are kept as ``None``.

    Returns:
        The propagated features, or ``(out, alpha)`` when
        ``return_attention_weights`` is truthy; ``self.alpha`` is then reset
        to ``None``.
    """
    is_single = torch.is_tensor(x)
    if is_single and size is None:
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(self.node_dim))

    if is_single:
        x = torch.matmul(x, self.weight)
    else:
        src, dst = x[0], x[1]
        x = (None if src is None else torch.matmul(src, self.weight),
             None if dst is None else torch.matmul(dst, self.weight))

    out = self.propagate(edge_index, size=size, x=x,
                         return_attention_weights=return_attention_weights)
    if not return_attention_weights:
        return out
    alpha, self.alpha = self.alpha, None
    return out, alpha
def pool_edge(cluster, edge_index, edge_attr=None):
    """Map edges onto cluster ids and simplify the resulting graph.

    Edges whose endpoints fall into the same cluster become self-loops and
    are dropped; the rest are coalesced (duplicates merged) together with
    ``edge_attr``.

    Args:
        cluster: 1-D tensor with the cluster id of every node.
        edge_index: edge indices into ``cluster``.
        edge_attr: optional per-edge attributes, carried along.

    Returns:
        tuple: ``(edge_index, edge_attr)`` of the pooled graph.
    """
    n = cluster.size(0)
    pooled = cluster[edge_index.view(-1)].view(2, -1)
    pooled, edge_attr = remove_self_loops(pooled, edge_attr)
    pooled, edge_attr = coalesce(pooled, edge_attr, n, n)
    return pooled, edge_attr
def __call__(self, data):
    """Add two-hop edges to ``data``.

    Two-hop pairs come from the sparse product of the edge set with itself;
    self-loops in that product are discarded. When ``data.edge_attr`` exists,
    attributes are merged with ``op='min'`` against a large sentinel value.
    Edges that only exist as two-hop edges end up with attribute 0, and
    original edges keep their attribute (assuming it is below the sentinel).
    """
    edge_index, edge_attr = data.edge_index, data.edge_attr
    n = data.num_nodes
    fill = 1e16  # sentinel carried by the two-hop values

    sentinel_vals = edge_index.new_full((edge_index.size(1), ), fill, dtype=torch.float)
    two_hop, two_hop_val = spspmm(edge_index, sentinel_vals, edge_index, sentinel_vals, n, n, n)
    two_hop, two_hop_val = remove_self_loops(two_hop, two_hop_val)
    merged = torch.cat([edge_index, two_hop], dim=1)

    if edge_attr is None:
        data.edge_index, _ = coalesce(merged, None, n, n)
        return data

    # Broadcast the scalar two-hop values to edge_attr's trailing shape.
    two_hop_val = two_hop_val.view(-1, *([1] * (edge_attr.dim() - 1)))
    two_hop_val = two_hop_val.expand(-1, *list(edge_attr.size())[1:])
    merged_attr = torch.cat([edge_attr, two_hop_val], dim=0)
    data.edge_index, merged_attr = coalesce(merged, merged_attr, n, n, op='min', fill_value=fill)
    merged_attr[merged_attr >= fill] = 0
    data.edge_attr = merged_attr
    return data
def __init__(self, name, num_fixed_features, use_rand_feats):
    """Load the edge-list dataset ``name`` and wrap it in a ``Data`` object.

    Args:
        name: sub-directory of ``../data`` (relative to this file) read by
            ``load_edge_list``.
        num_fixed_features: width of the random feature matrix.
        use_rand_feats: if truthy, node features are a random row shuffle of
            standard-normal noise of width ``num_fixed_features``; otherwise
            the dense adjacency matrix is used as the feature matrix.
    """
    self.path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    self.edges, self.objects, self.weights = load_edge_list(self.path, False)
    num_nodes = len(self.objects)

    self.feats = torch.randn(num_nodes, num_fixed_features, requires_grad=False)
    # perm covers every row, so this is a random row shuffle of self.feats.
    perm = torch.randperm(self.feats.size(0))
    feats = self.feats[perm[:num_nodes]]

    G = nx.Graph()
    G.add_edges_from(self.edges)
    edge_index = torch.tensor(list(G.edges)).t().contiguous()
    edge_index, _ = remove_self_loops(edge_index)

    if use_rand_feats:
        self.num_features = num_fixed_features
        self.dataset = Data(edge_index=edge_index, x=feats)
    else:
        # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array
        # yields the same dense adjacency (rows in G.nodes order) as ndarray.
        # NOTE(review): row order follows G.nodes insertion order, which only
        # matches edge_index ids if nodes are labelled 0..N-1 in that order.
        adj_mat = torch.Tensor(nx.to_numpy_array(G))
        self.dataset = Data(edge_index=edge_index, x=adj_mat)
        self.num_features = num_nodes
    self.reconstruction_loss = None
def structured_negative_sampling_feasible(
        edge_index: Tensor, num_nodes: Optional[int] = None,
        contains_neg_self_loops: bool = True) -> bool:
    r"""Returns :obj:`True` if
    :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
    on the graph given by :obj:`edge_index`.

    :obj:`~torch_geometric.utils.structured_negative_sampling` is infeasible
    if at least one node is connected to all other nodes, since no negative
    partner is left for it.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        contains_neg_self_loops (bool, optional): If set to
            :obj:`False`, sampled negative edges will not contain self loops.
            (default: :obj:`True`)

    :rtype: bool
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    # Duplicate edges must not inflate the degree.
    edge_index = coalesce(edge_index, num_nodes=num_nodes)

    if contains_neg_self_loops:
        max_num_neighbors = num_nodes
    else:
        # A node cannot be its own negative: one fewer candidate, and
        # existing self-loops must not count towards its degree.
        edge_index, _ = remove_self_loops(edge_index)
        max_num_neighbors = num_nodes - 1

    out_degree = degree(edge_index[0], num_nodes)
    return bool((out_degree < max_num_neighbors).all())
def recon_loss1(self, z, edge_index, batch):
    r"""Binary cross-entropy loss on positive and sampled negative edges.

    Given latent variables :obj:`z`, scores the positive edges
    :obj:`edge_index` and an equal-size sample of negative edges with
    ``self.edge_recon``.

    Args:
        z (Tensor): The latent space :math:`\mathbf{Z}`.
        edge_index (LongTensor): The positive edges to train against.
        batch: Unused; kept for interface compatibility.

    Returns:
        Tensor: ``-mean(log p_pos) - mean(log(1 - p_neg))``.
    """
    EPS = 1e-15
    recon_adj = self.edge_recon(z, edge_index)
    pos_loss = -torch.log(recon_adj + EPS).mean()

    # Do not include self-loops in negative samples: give every node exactly
    # one self-loop so negative_sampling treats it as an existing edge.
    # num_nodes covers nodes beyond the largest index present in edge_index.
    pos_edge_index, _ = remove_self_loops(edge_index)
    pos_edge_index, _ = add_self_loops(pos_edge_index, num_nodes=z.size(0))
    neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg_loss = -torch.log(1 - self.edge_recon(z, neg_edge_index) + EPS).mean()
    return pos_loss + neg_loss
def forward(self, x, edge_index, size=None):
    """Propagate transformed features and, in training, score negative pairs.

    Side effects:
        * ``self.cache["edge_index"]`` receives the (possibly
          self-loop-normalised) edge index.
        * In training mode, ``self.cache["att_neg"]`` receives the stacked
          unnormalised attention of ``self.num_neg_samples_per_edge`` rounds
          of structured negative samples, stacked on the last dimension.

    Returns:
        The output of ``self.propagate``.
    """
    if torch.is_tensor(x) and size is None:
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(self.node_dim))
    self.cache["edge_index"] = edge_index

    if torch.is_tensor(x):
        x = torch.matmul(x, self.weight)
    else:
        first, second = x[0], x[1]
        x = (None if first is None else torch.matmul(first, self.weight),
             None if second is None else torch.matmul(second, self.weight))

    propagated = self.propagate(edge_index, size=size, x=x)

    if not self.training:
        return propagated

    neg_scores = []
    for _ in range(self.num_neg_samples_per_edge):
        src, _dst, neg = structured_negative_sampling(
            edge_index=edge_index,
            num_nodes=x.size(0),
        )
        neg_scores.append(self.get_unnormalized_attention(x[src], x[neg]))
    self.cache["att_neg"] = torch.stack(neg_scores, dim=-1)  # [E, heads, num_neg]
    return propagated
def test_attr():
    """Evaluate the global ``model`` with macro-F1 on train/val/test masks.

    For the train split, negatives are sampled fresh with
    ``negative_sampling``; self-loops are added first so they are never
    sampled. For other splits, the stored ``{split}_pos_edge_index`` /
    ``{split}_neg_edge_index`` attributes of the global ``data`` are used.

    Returns:
        list: macro-F1 scores, one per mask yielded by ``data(...)``.
    """
    model.eval()
    accs = []
    mask_names = ['train_mask', 'val_mask', 'test_mask']
    # Use the key yielded by data(...) instead of a positional counter, so a
    # missing mask cannot shift the split labels.
    for mask_name, mask in data(*mask_names):
        split = mask_name.split("_")[0]
        if mask_name == 'train_mask':
            x, pos_edge_index = data.x, data.train_pos_edge_index
            _edge_index, _ = remove_self_loops(pos_edge_index)
            pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index, num_nodes=x.size(0))
            neg_edge_index = negative_sampling(
                edge_index=pos_edge_index_with_self_loops,
                num_nodes=x.size(0),
                num_neg_samples=pos_edge_index.size(1))
        else:
            pos_edge_index, neg_edge_index = [
                index for _, index in data("{}_pos_edge_index".format(split),
                                           "{}_neg_edge_index".format(split))
            ]
        _, logits, _ = model(pos_edge_index, neg_edge_index)
        pred = logits[mask].max(1)[1]
        macro = f1_score((data.y[mask]).cpu().numpy(), pred.cpu().numpy(), average='macro')
        accs.append(macro)
    return accs
def forward(self, x, pos, edge_index):
    """Propagate over ``edge_index`` with self-loops removed.

    ``self.define_message(x)`` is called first (presumably it configures the
    message function and ``self.x_is_none`` — confirm). When
    ``self.x_is_none`` is truthy, only ``pos`` is propagated. Otherwise ``x``
    is propagated together with ``pos[0]``.
    """
    edge_index = remove_self_loops(edge_index)[0]
    self.define_message(x)
    if not self.x_is_none:
        return self.propagate(edge_index, x=x, pos=pos[0])
    return self.propagate(edge_index, pos=pos)
def __norm__(self, edge_index, num_nodes: Optional[int],
             edge_weight: OptTensor, normalization: Optional[str],
             lambda_max, dtype: Optional[int] = None,
             batch: OptTensor = None):
    """Rescale the graph Laplacian by ``2 / lambda_max`` and append ``-1`` self-loops.

    Existing self-loops are dropped, the Laplacian is computed with
    ``get_laplacian``, and its weights are scaled by ``2 / lambda_max``
    (infinities from a zero ``lambda_max`` become 0). With a batch of graphs
    and one ``lambda_max`` per graph, each edge uses its source node's graph
    value. Finally one self-loop per node with weight ``-1`` is appended,
    which subtracts the identity.

    Returns:
        tuple: ``(edge_index, edge_weight)``.
    """
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                            normalization, dtype, num_nodes)

    per_graph = batch is not None and lambda_max.numel() > 1
    if per_graph:
        lambda_max = lambda_max[batch[edge_index[0]]]

    scaled = (2.0 * edge_weight) / lambda_max
    scaled.masked_fill_(scaled == float('inf'), 0)

    edge_index, scaled = add_self_loops(edge_index, scaled,
                                        fill_value=-1., num_nodes=num_nodes)
    assert scaled is not None
    return edge_index, scaled
def forward(self, x, edge_index, edge_attr, weight, bias, edge_encoder_w,
            edge_encoder_b, size=None):
    """Convolution with externally supplied weights and encoded edge features.

    Args:
        x: node features.
        edge_index: graph connectivity.
        edge_attr: per-edge attributes aligned with ``edge_index``.
        weight: node transform matrix applied to ``x``.
        bias: forwarded to ``propagate``.
        edge_encoder_w, edge_encoder_b: parameters of the linear edge encoder.
        size: optional size; when given (or when ``x`` is not a tensor),
            ``edge_index``/``edge_attr`` are used unchanged.

    When ``size is None`` and ``x`` is a tensor, existing self-loops are
    removed together with their attributes. One self-loop per node is then
    appended, with an attribute vector that is all zeros except column 7.
    """
    if size is None and torch.is_tensor(x):
        # Remove self-loop attributes too, so edge_attr stays row-aligned
        # with edge_index.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # add features to self-loop edges; column 7 is set "for bio dataset"
        # (NOTE(review): presumably the self-loop edge type there — confirm).
        self_loop_attr = torch.zeros(x.size(0), self.edge_in_channels,
                                     device=edge_attr.device, dtype=edge_attr.dtype)
        self_loop_attr[:, 7] = 1
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

    edge_emb = F.linear(edge_attr, edge_encoder_w, edge_encoder_b)
    x = torch.matmul(x, weight)
    return self.propagate(edge_index, size=size, x=x, bias=bias, edge_attr=edge_emb)
def forward(self, x: Union[OptTensor, PairOptTensor],
            pos: Union[Tensor, PairTensor],
            normal: Union[Tensor, PairTensor],
            edge_index: Adj) -> Tensor:  # yapf: disable
    """Convolve over points described by features, positions and normals.

    Single inputs are expanded to (source, target) pairs: ``x`` becomes
    ``(x, None)``, while ``pos`` and ``normal`` are duplicated. If
    ``self.add_self_loops`` is set, a dense ``edge_index`` gets exactly one
    self-loop per target node, and a ``SparseTensor`` gets its diagonal set.
    ``self.global_nn``, when present, is applied to the aggregated output.
    """
    if not isinstance(x, tuple):
        x: PairOptTensor = (x, None)
    if isinstance(pos, Tensor):
        pos: PairTensor = (pos, pos)
    if isinstance(normal, Tensor):
        normal: PairTensor = (normal, normal)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=pos[1].size(0))
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # propagate_type: (x: PairOptTensor, pos: PairTensor, normal: PairTensor)  # noqa
    out = self.propagate(edge_index, x=x, pos=pos, normal=normal, size=None)
    return out if self.global_nn is None else self.global_nn(out)
def forward(self, x, edge_index, size=None, edge_weights=None, return_alpha=False):
    """Propagate weighted messages, optionally also returning attention.

    For a single feature tensor without an explicit ``size``, self-loops are
    replaced by one per node and ``edge_weights`` is kept aligned. The
    resulting weights are stored on ``self.edge_weights`` for use during
    propagation, and ``return_alpha`` on ``self.return_alpha``.

    Returns:
        The propagated features, or ``(out, self.alpha, edge_index)`` when
        ``return_alpha`` is truthy.
    """
    self.return_alpha = return_alpha
    if torch.is_tensor(x) and size is None:
        edge_index, edge_weights = remove_self_loops(edge_index, edge_weights)
        edge_index, edge_weights = add_self_loops(
            edge_index, edge_weight=edge_weights,
            num_nodes=x.size(self.node_dim))
    self.edge_weights = edge_weights

    if torch.is_tensor(x):
        x = torch.matmul(x, self.weight)
    else:
        src, dst = x[0], x[1]
        x = (None if src is None else torch.matmul(src, self.weight),
             None if dst is None else torch.matmul(dst, self.weight))

    out = self.propagate(edge_index, size=size, x=x)
    if return_alpha:
        return out, self.alpha, edge_index
    return out
def barabasi_albert_graph(num_nodes, num_edges):
    r"""Returns the :obj:`edge_index` of a Barabasi-Albert preferential
    attachment model.

    The graph of :obj:`num_nodes` nodes grows by attaching each new node with
    :obj:`num_edges` edges. The edges are preferentially attached to existing
    nodes with high degree.

    Args:
        num_nodes (int): The number of nodes.
        num_edges (int): The number of edges from a new node to existing
            nodes.
    """
    assert num_edges > 0 and num_edges < num_nodes

    row, col = torch.arange(num_edges), torch.randperm(num_edges)
    for new_node in range(num_edges, num_nodes):
        row = torch.cat([row, torch.full((num_edges, ), new_node, dtype=torch.long)])
        # Uniform sampling over all edge endpoints is roughly degree-
        # proportional. It may pick new_node itself; the resulting
        # self-loops are removed below.
        endpoints = torch.cat([row, col]).numpy()
        targets = np.random.choice(endpoints, num_edges)
        col = torch.cat([col, torch.from_numpy(targets)])

    edge_index = torch.stack([row, col], dim=0)
    edge_index, _ = remove_self_loops(edge_index)
    return to_undirected(edge_index, num_nodes=num_nodes)
def forward(self, x, edge_index):
    """Project ``x`` per head and propagate with one self-loop per node.

    ``x`` is multiplied by ``self.weight`` and reshaped to
    ``(-1, self.heads, self.out_channels)`` before message passing.
    """
    num_nodes = x.size(0)
    edge_index = remove_self_loops(edge_index)[0]
    edge_index = add_self_loops(edge_index, num_nodes=num_nodes)[0]
    h = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
    return self.propagate(edge_index, x=h, num_nodes=h.size(0))
def forward(self, x, edge_index, pseudo):
    """Gaussian-mixture convolution over pseudo-coordinates.

    For every edge and every input channel, a Gaussian weight is computed
    from the edge's pseudo-coordinates using ``self.mu`` / ``self.sigma``. The
    weights are normalised over each ``row`` node's edges, multiply the
    neighbour features ``x[col]``, and are summed per node. The result then
    passes through ``self.lin``.

    Args:
        x: node features; a 1-D tensor is treated as a single channel.
        edge_index: graph connectivity; self-loops are replaced by one per node.
        pseudo: per-edge pseudo-coordinates; 1-D is treated as one dimension.
            NOTE(review): its first dimension must match the edge count
            *after* the self-loop replacement below — confirm with callers.
    """
    # See https://github.com/shchur/gnn-benchmark for the reference
    # TensorFlow implementation.
    edge_index, _ = remove_self_loops(edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
    row, col = edge_index

    # One Gaussian per feature channel: the (E, C) weights multiply x[col]
    # element-wise, so C is x.size(1) (not the node count x.size(0)).
    # Named num_channels to avoid shadowing torch.nn.functional as F.
    num_channels, (E, D) = x.size(1), pseudo.size()
    gaussian = -0.5 * (pseudo.view(E, 1, D) - self.mu.view(1, num_channels, D))**2
    gaussian = torch.exp(gaussian / (1e-14 + self.sigma.view(1, num_channels, D)**2))
    gaussian = gaussian.prod(dim=-1)

    # Normalize gaussians in edge dimension.
    gaussian_mean = scatter_add(gaussian, row, dim=0, dim_size=x.size(0))
    gaussian = gaussian / (1e-14 + gaussian_mean[row]).view(E, num_channels)

    out = scatter_add(x[col] * gaussian, row, dim=0, dim_size=x.size(0))
    out = self.lin(out)
    return out
def forward(self, x, edge_index, return_attention_weights=False):
    """Attention convolution with one self-loop per target node.

    ``self.lin`` is applied to ``x``, or to each element of a (source,
    target) pair. Heads are concatenated when ``self.concat`` is truthy and
    averaged otherwise, and ``self.bias`` is added when present.

    Returns:
        ``out``, or ``(out, (edge_index, alpha))`` when
        ``return_attention_weights`` is truthy; ``self.__alpha__`` is then
        reset to ``None``.
    """
    if torch.is_tensor(x):
        h = self.lin(x)
        x = (h, h)
    else:
        x = (self.lin(x[0]), self.lin(x[1]))

    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = add_self_loops(edge_index, num_nodes=x[1].size(self.node_dim))

    out = self.propagate(edge_index, x=x,
                         return_attention_weights=return_attention_weights)
    out = out.view(-1, self.heads * self.out_channels) if self.concat else out.mean(dim=1)
    if self.bias is not None:
        out = out + self.bias

    if not return_attention_weights:
        return out
    alpha, self.__alpha__ = self.__alpha__, None
    return out, (edge_index, alpha)
def process(self):
    """Convert the raw per-split files into collated datasets.

    For each split ``train``/``valid``/``test``, ``self.raw_dir`` must hold:
    ``{split}_graph.json`` (node-link graph), ``{split}_feats.npy``,
    ``{split}_labels.npy`` and ``{split}_graph_id.npy`` (a graph id per node).
    Each graph id becomes one ``Data`` object with edges re-indexed to start
    at 0. ``pre_filter``/``pre_transform`` are applied, and the collated list
    is saved to ``self.processed_paths[s]``.
    """
    for s, split in enumerate(['train', 'valid', 'test']):
        # Format only the file name: formatting the joined path breaks when
        # raw_dir itself contains '{' or '}'.
        path = osp.join(self.raw_dir, '{}_graph.json'.format(split))
        with open(path, 'r') as f:
            G = nx.DiGraph(json_graph.node_link_graph(json.load(f)))

        x = np.load(osp.join(self.raw_dir, '{}_feats.npy'.format(split)))
        x = torch.from_numpy(x).to(torch.float)

        y = np.load(osp.join(self.raw_dir, '{}_labels.npy'.format(split)))
        y = torch.from_numpy(y).to(torch.float)

        data_list = []
        path = osp.join(self.raw_dir, '{}_graph_id.npy'.format(split))
        idx = torch.from_numpy(np.load(path)).to(torch.long)
        idx = idx - idx.min()

        for i in range(idx.max().item() + 1):
            mask = idx == i
            node_ids = mask.nonzero().view(-1)
            G_s = G.subgraph(node_ids.tolist())
            edges = list(G_s.edges)
            if edges:
                edge_index = torch.tensor(edges).t().contiguous()
                # Offset by the graph's first node, not by the smallest node
                # that happens to have an edge, so ids line up with x[mask].
                # NOTE(review): assumes each graph's node ids are contiguous.
                edge_index = edge_index - node_ids.min()
            else:
                # Edge-less graph: nothing to re-index.
                edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_index, _ = remove_self_loops(edge_index)

            data = Data(edge_index=edge_index, x=x[mask], y=y[mask])

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        torch.save(self.collate(data_list), self.processed_paths[s])
def forward(self, x, edge_index, mask=None, edge_weight=None, size=None):
    """Propagate (optionally masked) features over ``edge_index`` without self-loops.

    Args:
        x: node features.
        edge_index: graph connectivity; self-loops are removed.
        mask: optional multiplier applied to ``x`` before propagation.
        edge_weight: optional per-edge weights aligned with ``edge_index``.
        size: optional size forwarded to ``propagate``.

    The masked features are passed to ``propagate`` both as ``x`` and ``h``.
    """
    # Remove self-loop weights together with the edges so both stay aligned.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    if mask is not None:
        x = x * mask
    h = x
    return self.propagate(edge_index, size=size, x=x, h=h, edge_weight=edge_weight)
def norm(edge_index, num_nodes, edge_weight, dtype=None):
    """Return ``I - D^{-1/2} A D^{-1/2}`` as (edge_index, edge weights).

    Existing self-loops are dropped together with their weights, and degrees
    are computed from the remaining edges (zero-degree nodes get a 0 factor).
    Then one self-loop per node is appended with an ``A`` weight of 0 and an
    identity weight of 1.

    Args:
        edge_index: edge indices.
        num_nodes: number of nodes.
        edge_weight: optional per-edge weights; defaults to ones.
        dtype: dtype of the default weights and of the identity term.

    Returns:
        tuple: ``(edge_index, weights)``. The code assumes ``add_self_loops``
        appends the ``num_nodes`` self-loops at the end.
    """
    # Remove self-loop weights together with the edges so both stay aligned.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    # Self-loops enter A with weight 0; the identity is added separately.
    edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 0, num_nodes)
    row, col = edge_index
    identity = torch.zeros((edge_weight.size(0), ), dtype=dtype, device=edge_index.device)
    identity[-num_nodes:] = torch.ones((num_nodes, ), dtype=dtype, device=edge_index.device)
    return edge_index, identity - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, size=None):
    """Attention-style propagation that also caches degree and ReLU coefficients.

    Side effects:
        * ``self.degree``: per-node counts over ``edge_index[1]`` (dense), or
          column sums (``SparseTensor``).
        * ``self.theta``: ``self.get_relu_coefs`` reshaped to
          ``(-1, self.channels, 2 * self.k)``.

    When ``self.add_self_loops`` is set, a dense ``edge_index`` gets exactly
    one self-loop per node, and a ``SparseTensor`` gets its diagonal set.
    """
    x_l = x_r = x
    alpha_l = (x_l * self.att_l).sum(dim=-1)
    alpha_r = (x_r * self.att_r).sum(dim=-1)

    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            if size is not None:
                num_nodes = min(size[0], size[1])
            elif x_r is None:
                num_nodes = x_l.size(0)
            else:
                num_nodes = min(x_l.size(0), x_r.size(0))
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    if isinstance(edge_index, SparseTensor):
        self.degree = edge_index.sum(dim=0)
    else:
        self.degree = degree(edge_index[1])

    theta = self.get_relu_coefs(x, edge_index)
    self.theta = theta.view(-1, self.channels, 2 * self.k)

    return self.propagate(edge_index, x=(x_l, x_r),
                          alpha=(alpha_l, alpha_r), size=size)
def forward(self, x, edge_index):
    """Transform node features, then pool them over each node's neighbourhood.

    Features are mapped by ``self.weight`` (+ ``self.bias`` when present) and
    ``self.act``. For every ``row`` node, the transformed features at the
    ``col`` endpoints of its edges are reduced with ``self.pool``. Self-loops
    are replaced by one per node, so a node's own features are included.
    The result is optionally L2-normalised.

    Raises:
        ValueError: if ``self.pool`` is not ``'mean'``, ``'max'`` or ``'add'``.
    """
    edge_index, _ = remove_self_loops(edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x
    row, col = edge_index

    if self.pool not in ('mean', 'max', 'add'):
        # Previously this only printed and then crashed on an unbound `out`.
        raise ValueError('pooling not defined! got {!r}'.format(self.pool))

    # Shared transform; 'add' now follows the same order as 'mean'/'max'
    # (it previously read `out` before assignment).
    out = torch.matmul(x, self.weight)
    if self.bias is not None:
        out = out + self.bias
    out = self.act(out)

    if self.pool == 'mean':
        out = scatter_mean(out[col], row, dim=0, dim_size=out.size(0))
    elif self.pool == 'max':
        out, _ = scatter_max(out[col], row, dim=0, dim_size=out.size(0))
    else:
        out = scatter_add(out[col], row, dim=0, dim_size=out.size(0))

    if self.normalize:
        out = F.normalize(out, p=2, dim=-1)
    return out
def compat_matrix_edge_idx(edge_idx, labels):
    """Compute the c x c class compatibility matrix of a graph.

    ``H[i, j]`` is the proportion of edges whose source has class ``i`` and
    whose target has class ``j``, among all edges leaving class-``i`` nodes
    (see "Generalizing GNNs Beyond Homophily"). Self-loops are ignored, and
    negative labels are treated as unlabeled: edges touching such a node are
    skipped.

    Args:
        edge_idx: edge indices, row 0 = source and row 1 = target.
        labels: node labels, shape ``(N,)`` or ``(N, 1)``.

    Returns:
        Tensor: row-normalised ``(c, c)`` matrix with ``c = max label + 1``.
        A class with no labeled outgoing edge yields a NaN row (0 / 0).
    """
    edge_index = remove_self_loops(edge_idx)[0]
    src_node, targ_node = edge_index[0, :], edge_index[1, :]
    # Squeeze before masking: with (N, 1) labels an (E, 1) mask cannot index
    # the 1-D src_node / targ_node tensors.
    label = labels.squeeze()
    labeled = (label[src_node] >= 0) & (label[targ_node] >= 0)
    c = int(label.max()) + 1
    H = torch.zeros((c, c), device=edge_index.device)
    src_label = label[src_node[labeled]]
    targ_label = label[targ_node[labeled]]
    # Count (source class, target class) pairs in a single accumulate pass.
    H.index_put_((src_label, targ_label),
                 torch.ones_like(src_label, dtype=H.dtype), accumulate=True)
    return H / torch.sum(H, dim=1, keepdim=True)
def forward(self, x, edge_index):
    """Degree-specific convolution.

    Each node's degree selects its transforms. The degree is counted on the
    flow's target side after removing self-loops and clamped to
    ``self.max_degree``; it indexes ``self.rel_lins`` (and
    ``self.root_lins`` when ``self.root_weight``). Without a root weight,
    self-loops are added so a node's own features enter the aggregation.
    """
    edge_index, _ = remove_self_loops(edge_index)
    target_row = 1 if self.flow == 'source_to_target' else 0
    deg = degree(edge_index[target_row], x.size(0), dtype=torch.long)
    deg.clamp_(max=self.max_degree)

    if not self.root_weight:
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(self.node_dim))

    h = self.propagate(edge_index, x=x)

    out = x.new_empty(list(x.size())[:-1] + [self.out_channels])
    for d in deg.unique().tolist():
        nodes = (deg == d).nonzero().view(-1)
        res = self.rel_lins[d](h.index_select(self.node_dim, nodes))
        if self.root_weight:
            res = res + self.root_lins[d](x.index_select(self.node_dim, nodes))
        out.index_copy_(self.node_dim, nodes, res)
    return out
def forward(self, x, edge_index, edge_attr=None):
    """Graph convolution ``D^{-1/2} (A + kI) D^{-1/2} X W (+ b)``.

    ``k`` is 1, or 2 when ``self.improved``. ``edge_attr`` holds scalar edge
    weights (defaults to ones). Existing self-loops are replaced by the
    ``k``-weighted ones, and zero-degree nodes get a normalisation factor of 0.
    """
    if x.dim() == 1:
        x = x.unsqueeze(-1)
    num_nodes = x.size(0)

    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    if edge_attr is None:
        edge_attr = x.new_ones((edge_index.size(1), ))
    assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

    # Append one self-loop per node with weight k.
    edge_index = add_self_loops(edge_index, num_nodes)
    self_weight = 2 if self.improved else 1
    edge_attr = torch.cat([edge_attr, x.new_full((num_nodes, ), self_weight)], dim=0)

    # Symmetric degree normalisation.
    row, col = edge_index
    deg_inv_sqrt = scatter_add(edge_attr, row, dim=0, dim_size=num_nodes).pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm_weight = deg_inv_sqrt[row] * edge_attr * deg_inv_sqrt[col]

    transformed = torch.mm(x, self.weight)
    out = spmm(edge_index, norm_weight, transformed.size(0), transformed)
    return out if self.bias is None else out + self.bias