def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Concatenate ``x`` with its 1..K-hop propagations and project them.

    If ``self.normalize`` is set, the graph is first normalized with
    ``gcn_norm`` without adding self-loops.

    Args:
        x: Node feature tensor.
        edge_index: Graph connectivity as a ``Tensor`` or ``SparseTensor``.
        edge_weight: Optional edge weights.

    Returns:
        ``self.lin`` applied to ``[x, P x, P^2 x, ..., P^K x]`` concatenated
        along the last dimension.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                improved=False, add_self_loops=False, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                add_self_loops=False, dtype=x.dtype)

    hops = [x]
    for _ in range(self.K):
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        next_hop = self.propagate(edge_index, x=hops[-1],
                                  edge_weight=edge_weight, size=None)
        hops.append(next_hop)
    return self.lin(torch.cat(hops, dim=-1))
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Sum per-hop linear projections of the propagated features.

    ``self.lins[0]`` projects the raw input. Every following layer in
    ``self.lins`` projects the features after one more propagation step.
    All projections are accumulated into a single output.

    If ``self.normalize`` is set, the graph is first normalized with
    ``gcn_norm`` without adding self-loops.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                improved=False, add_self_loops=False, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                add_self_loops=False, dtype=x.dtype)

    h = x
    total = self.lins[0](h)
    for hop_lin in self.lins[1:]:
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        h = self.propagate(edge_index, x=h, edge_weight=edge_weight,
                           size=None)
        total += hop_lin.forward(h)
    return total
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Run ``self.K`` steps of APPNP-style propagation.

    Each step computes ``(1 - alpha) * propagate(x) + alpha * x_input``.
    When ``self.normalize`` is set, the graph is normalized with ``gcn_norm``,
    and the result is cached if ``self.cached`` is true. When training with
    ``self.dropout > 0``, dropout is applied to the edge weights (or the
    ``SparseTensor`` values) before every step.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            cached = self._cached_edge_index
            if cached is not None:
                edge_index, edge_weight = cached[0], cached[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
        elif isinstance(edge_index, SparseTensor):
            cached = self._cached_adj_t
            if cached is not None:
                edge_index = cached
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index

    x_input = x
    drop_edges = self.dropout > 0 and self.training
    for _ in range(self.K):
        if drop_edges:
            # NOTE(review): the dropped weights overwrite edge_weight /
            # edge_index, so dropout compounds over the K iterations
            # (step k sees weights dropped k times). Confirm this is intended.
            if isinstance(edge_index, Tensor):
                assert edge_weight is not None
                edge_weight = F.dropout(edge_weight, p=self.dropout)
            else:
                values = edge_index.storage.value()
                assert values is not None
                edge_index = edge_index.set_value(
                    F.dropout(values, p=self.dropout), layout='coo')
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                           size=None)
        x = x * (1 - self.alpha)
        x += self.alpha * x_input
    return x
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    r"""Propagate layer-stacked node features once over the graph.

    Args:
        x: The input node features of shape
            :obj:`[num_nodes, num_layers, channels]`.
        edge_index: Graph connectivity as a ``Tensor`` or ``SparseTensor``.
        edge_weight: Optional edge weights.

    Raises:
        ValueError: If ``x`` is not three-dimensional.

    When ``self.normalize`` is set, the graph is normalized with
    ``gcn_norm``, and the result is cached if ``self.cached`` is true.
    """
    if x.dim() != 3:
        raise ValueError('Feature shape must be [num_nodes, num_layers, '
                         'channels].')

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cached = self._cached_edge_index
            if cached is not None:
                edge_index, edge_weight = cached[0], cached[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
        elif isinstance(edge_index, SparseTensor):
            cached = self._cached_adj_t
            if cached is not None:
                edge_index = cached
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, self.flow, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                          size=None)
def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """GCNII-style convolution with initial residual and identity mapping.

    The propagated features are mixed with the initial features ``x_0``
    through ``alpha``. Each term is then blended with its product by a
    weight matrix through ``beta``: ``(1 - beta) * h + beta * (h @ W)``.
    When ``self.weight2`` is set, the propagated and initial terms use
    separate matrices (``weight1`` and ``weight2``). Otherwise their sum uses
    ``weight1``.

    When ``self.normalize`` is set, the graph is normalized with
    ``gcn_norm``, and the result is cached if ``self.cached`` is true.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            cached = self._cached_edge_index
            if cached is not None:
                edge_index, edge_weight = cached[0], cached[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
        elif isinstance(edge_index, SparseTensor):
            cached = self._cached_adj_t
            if cached is not None:
                edge_index = cached
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)

    if self.weight2 is None:
        mixed = (1 - self.alpha) * x + self.alpha * x_0
        return (1 - self.beta) * mixed + self.beta * (mixed @ self.weight1)

    prop_part = (1 - self.alpha) * x
    prop_part = ((1 - self.beta) * prop_part
                 + self.beta * (prop_part @ self.weight1))
    init_part = self.alpha * x_0
    init_part = ((1 - self.beta) * init_part
                 + self.beta * (init_part @ self.weight2))
    return prop_part + init_part
def __call__(self, data):
    """GCN-normalize the graph stored in ``data``.

    When ``data.edge_index`` is set, edge weights come from
    ``data.edge_weight`` if present and otherwise from ``data.edge_attr``.
    The normalized ``edge_index`` and ``edge_weight`` are written back to
    ``data``. Otherwise ``data.adj_t`` is replaced by its normalized form.

    Returns:
        ``data``, modified in place.
    """
    assert data.edge_index is not None or data.adj_t is not None
    if data.edge_index is None:
        data.adj_t = gcn_norm(data.adj_t)
        return data

    weights = data.edge_attr
    if 'edge_weight' in data:
        weights = data.edge_weight
    data.edge_index, data.edge_weight = gcn_norm(
        data.edge_index, weights, data.num_nodes)
    return data
def forward(self, x, edge_index, edge_weight=None):
    """Generalized-PageRank propagation (GPR-GNN style).

    Returns ``sum_{k=0..K} temp[k] * P^k x``. Here ``P`` is one
    ``self.propagate`` step over the graph normalized by ``gcn_norm``, with
    ``gcn_norm``'s default self-loop handling. For ``torch.Tensor`` edge
    indices the per-edge coefficients are passed as ``norm``. For a
    ``SparseTensor`` they live in its values and ``norm`` is ``None``.
    """
    if isinstance(edge_index, torch.Tensor):
        edge_index, norm = gcn_norm(
            edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
    elif isinstance(edge_index, SparseTensor):
        edge_index = gcn_norm(
            edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
        norm = None

    coeffs = self.temp
    hidden = x * coeffs[0]
    for step in range(1, self.K + 1):
        x = self.propagate(edge_index, x=x, norm=norm)
        hidden = hidden + coeffs[step] * x
    return hidden
def __call__(self, data):
    """GCN-normalize ``data``'s graph, honouring ``self.add_self_loops``.

    If ``data`` has an ``edge_index``, edge weights come from
    ``data.edge_weight`` if present and otherwise from ``data.edge_attr``.
    The normalized ``edge_index`` and ``edge_weight`` are stored back.
    Otherwise ``data.adj_t`` is normalized.

    Returns:
        ``data``, modified in place.
    """
    assert 'edge_index' in data or 'adj_t' in data
    if 'edge_index' not in data:
        data.adj_t = gcn_norm(data.adj_t,
                              add_self_loops=self.add_self_loops)
        return data

    weights = data.edge_attr
    if 'edge_weight' in data:
        weights = data.edge_weight
    data.edge_index, data.edge_weight = gcn_norm(
        data.edge_index, weights, data.num_nodes,
        add_self_loops=self.add_self_loops)
    return data
def compute_energy(self, x, edge_index, device):
    """Return the Dirichlet energy of ``x`` before and after each SGC hop.

    The Laplacian used is ``I - A_hat``. ``A_hat`` is the dense adjacency
    built from the ``gcn_norm``-normalized edges. The first entry is the
    energy of the raw ``x``. If ``self.lin_first`` is set, ``self.SGC.lin``
    is applied afterwards. One further entry is appended per
    ``self.SGC.propagate`` step, for ``self.num_layers`` steps.

    Returns:
        A list of ``self.num_layers + 1`` energies.
    """
    num_nodes = x.size(0)
    edge_index, edge_weight = gcn_norm(edge_index, None, num_nodes, False,
                                       dtype=x.dtype)
    dense_adj = torch.squeeze(
        to_dense_adj(edge_index, edge_attr=edge_weight), dim=0)
    laplacian = torch.eye(
        num_nodes, dtype=torch.float, device=device) - dense_adj

    energies = [self.Dirichlet_energy(x, laplacian)]
    if self.lin_first:
        x = self.SGC.lin(x)
    for _ in range(self.num_layers):
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        x = self.SGC.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
        energies.append(self.Dirichlet_energy(x, laplacian))
    return energies
def forward(self, x, edge_index,
            edge_weight: Optional[torch.Tensor] = None):
    """Apply ``self.K`` propagation steps (optionally cached), then ``self.lin``.

    With ``self.cached`` set, the propagated features from the first call are
    stored with the edge count and reused on later calls.

    Raises:
        RuntimeError: If a cached result exists but ``edge_index`` now has a
            different number of edges.
    """
    cache = self._cache
    if cache is None:
        num_edges = edge_index.size(1)
        # NOTE(review): the positional order (num_nodes, edge_weight) differs
        # from PyG's gcn_norm(edge_index, edge_weight, num_nodes, ...).
        # Confirm which gcn_norm is imported here.
        edge_index, norm = gcn_norm(edge_index, x.size(self.node_dim),
                                    edge_weight, dtype=x.dtype)
        for _ in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
        if self.cached:
            self._cache = (num_edges, x)
    else:
        if edge_index.size(1) != cache[0]:
            raise RuntimeError(
                'Cached {} number of edges, but found {}. Please disable '
                'the caching behavior of this layer by removing the '
                '`cached=True` argument in its constructor.'.format(
                    cache[0], edge_index.size(1)))
        x = cache[1]
    return self.lin(x)
def forward(self, x, edge_index, edge_weight=None, params=None):
    """GCN-style forward pass with optionally externally supplied parameters.

    Args:
        x: Node features.
        edge_index: ``Tensor`` edge indices or ``SparseTensor`` adjacency.
        edge_weight: Optional edge weights.
        params: Optional mapping from parameter name to tensor, e.g. fast
            weights in a meta-learning loop. Must contain ``"weight"``.
            ``"bias"`` is optional. Defaults to this module's own named
            parameters.

    Returns:
        ``propagate(x @ params["weight"])``, plus ``params["bias"]`` when
        the module has a bias and ``params`` supplies one.

    When ``self.normalize`` is set, the graph is normalized with
    ``gcn_norm``, and the result is cached if ``self.cached`` is true.
    """
    if params is None:
        params = OrderedDict(self.named_parameters())
    bias = params.get("bias", None)

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]
        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    x = torch.matmul(x, params["weight"])
    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
    # The bias actually added comes from ``params``. Also check it is present,
    # otherwise ``out += None`` raises when ``params`` has no "bias" entry.
    if self.bias is not None and bias is not None:
        out += bias
    # Previously the result was never returned, so callers received None.
    return out
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """ARMA-style recursive filtering.

    The graph is always normalized with ``gcn_norm`` without adding
    self-loops. A new dimension is inserted at position ``-3`` of ``x``
    (presumably one slot per parallel stack, broadcast against the stacked
    weights; TODO confirm). ``self.num_layers`` rounds are then run. Each
    round applies a weight matrix, propagates, adds a dropout-regularized
    skip term from the input, adds an optional bias, and applies
    ``self.act`` (except after the last round). The result is averaged over
    the inserted dimension.
    """
    if isinstance(edge_index, Tensor):
        edge_index, edge_weight = gcn_norm(  # yapf: disable
            edge_index, edge_weight, x.size(self.node_dim),
            add_self_loops=False, dtype=x.dtype)
    elif isinstance(edge_index, SparseTensor):
        edge_index = gcn_norm(  # yapf: disable
            edge_index, edge_weight, x.size(self.node_dim),
            add_self_loops=False, dtype=x.dtype)

    x = x.unsqueeze(-3)
    out = x
    final_layer = self.num_layers - 1
    for layer in range(self.num_layers):
        if layer == 0:
            weight = self.init_weight
        else:
            weight = self.weight[0 if self.shared_weights else layer - 1]
        out = out @ weight
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                             size=None)

        slot = 0 if self.shared_weights else layer
        root = F.dropout(x, p=self.dropout, training=self.training)
        out += root @ self.root_weight[slot]
        if self.bias is not None:
            out += self.bias[slot]
        if layer < final_layer:
            out = self.act(out)
    return out.mean(dim=-3)
def forward(self, x, edge_index, edge_weight=None):
    """Generalized-PageRank propagation (GPR-GNN style).

    Returns ``sum_{k=0..K} temp[k] * P^k x``. Here ``P`` is one
    ``self.propagate`` step using the coefficients ``norm`` returned by
    ``gcn_norm``.
    """
    edge_index, norm = gcn_norm(
        edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
    coeffs = self.temp
    hidden = x * coeffs[0]
    for step in range(1, self.K + 1):
        x = self.propagate(edge_index, x=x, norm=norm)
        hidden = hidden + coeffs[step] * x
    return hidden
def forward(self, data, train_idx):
    """Label propagation over the ``gcn_norm``-normalized graph in ``data``.

    Seed labels are placed on ``train_idx`` in one of three ways:

    - If ``data.label`` has a single column, labels are one-hot encoded
      into ``self.out_channels`` classes.
    - Otherwise, if ``self.mult_bin`` is set, each label column is one-hot
      encoded into its own pair of columns.
    - Otherwise, labels are copied as-is.

    ``self.num_iters`` rounds are then run. Each round applies
    ``self.hops`` sparse matmuls with ``adj_t`` and then computes
    ``alpha * result + (1 - alpha) * y``. For ``self.mult_bin``, column
    ``2 * task + 1`` of each pair is returned per task.
    """
    n = data.graph['num_nodes']
    edge_index = data.graph['edge_index']
    if isinstance(edge_index, torch.Tensor):
        edge_index, edge_weight = gcn_norm(edge_index, None, n, False)
        row, col = edge_index
        # Swap row/col to build the transposed adjacency.
        adj_t = SparseTensor(row=col, col=row, value=edge_weight,
                             sparse_sizes=(n, n))
    elif isinstance(edge_index, SparseTensor):
        adj_t = gcn_norm(edge_index, None, n, False)

    labels = data.label
    num_tasks = labels.shape[1]
    if num_tasks == 1:
        y = torch.zeros((n, self.out_channels)).to(adj_t.device())
        y[train_idx] = F.one_hot(
            labels[train_idx], self.out_channels).squeeze(1).to(y)
    elif self.mult_bin:
        y = torch.zeros((n, 2 * self.out_channels)).to(adj_t.device())
        for task in range(num_tasks):
            y[train_idx, 2 * task:2 * task + 2] = F.one_hot(
                labels[train_idx, task], 2).to(y)
    else:
        y = torch.zeros((n, self.out_channels)).to(adj_t.device())
        y[train_idx] = labels[train_idx].to(y.dtype)

    result = y.clone()
    for _ in range(self.num_iters):
        for _ in range(self.hops):
            result = matmul(adj_t, result)
        result *= self.alpha
        result += (1 - self.alpha) * y

    if self.mult_bin:
        positives = torch.zeros((n, self.out_channels)).to(result.device)
        for task in range(num_tasks):
            positives[:, task] = result[:, 2 * task + 1]
        result = positives
    return result
def neighborhood_aggregation(self, x, adj_t):
    """Aggregate ``x`` over ``self.K`` propagation steps along ``adj_t``.

    With the ``'gcn'`` aggregator the adjacency is first normalized with
    ``gcn_norm``, honouring ``self.add_self_loops``. For any other
    aggregator, ``adj_t.set_diag()`` is applied when ``self.add_self_loops``
    is set.
    """
    if self.aggregator == 'gcn':
        adj_t = gcn_norm(adj_t, num_nodes=x.size(self.node_dim),
                         add_self_loops=self.add_self_loops, dtype=x.dtype)
    elif self.add_self_loops:
        adj_t = adj_t.set_diag()

    out = x
    for _ in range(self.K):
        out = self.propagate(adj_t, x=out)
    return out
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """SGC forward: ``self.K`` propagation steps, then ``self.lin``.

    The graph is normalized with ``gcn_norm`` before propagating. With
    ``self.cached`` set, the propagated features are stored. Later calls use
    a detached copy of them and skip normalization and propagation.
    """
    cached_x = self._cached_x
    if cached_x is not None:
        x = cached_x.detach()
    else:
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)

        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
        if self.cached:
            self._cached_x = x
    return self.lin(x)
def forward(
    self, y: Tensor, edge_index: Adj, mask: Optional[Tensor] = None,
    edge_weight: OptTensor = None,
    post_step: Callable = lambda y: y.clamp_(0., 1.)
) -> Tensor:
    """Propagate labels ``y`` over the graph for ``self.num_layers`` steps.

    1-D ``long`` labels are one-hot encoded first. If ``mask`` is given,
    only the masked rows seed the propagation and the rest start at zero.
    Unweighted graphs are normalized with ``gcn_norm`` without adding
    self-loops. Each step computes
    ``post_step(alpha * propagate(out) + (1 - alpha) * out_0)``.
    """
    if y.dtype == torch.long and y.size(0) == y.numel():
        y = F.one_hot(y.view(-1)).to(torch.float)

    if mask is None:
        out = y
    else:
        out = torch.zeros_like(y)
        out[mask] = y[mask]

    if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
        edge_index = gcn_norm(edge_index, add_self_loops=False)
    elif isinstance(edge_index, Tensor) and edge_weight is None:
        edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                           add_self_loops=False)

    residual = (1 - self.alpha) * out
    for _ in range(self.num_layers):
        # propagate_type: (y: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                             size=None)
        out.mul_(self.alpha).add_(residual)
        out = post_step(out)
    return out
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """Propagate ``x`` once, optionally over a ``gcn_norm``-normalized graph.

    When ``self.normalize`` is set, the graph is normalized without adding
    self-loops. For ``SparseTensor`` input the separate ``edge_weight`` is
    not passed to ``gcn_norm``.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            normalized = gcn_norm(edge_index, edge_weight,
                                  x.size(self.node_dim),
                                  add_self_loops=False, dtype=x.dtype)
            edge_index, edge_weight = normalized
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),
                                  add_self_loops=False, dtype=x.dtype)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                          size=None)
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """SGC-style forward with configurable placement of the linear layer.

    ``self.lin`` is applied before propagation when ``self.lin_first`` is
    set, and otherwise at the end. Between the two, the method:

    - normalizes the graph with ``gcn_norm`` and runs ``self.K``
      propagation steps (or reuses ``self._cached_x`` when present);
    - applies ``self.bn`` if it is truthy;
    - applies dropout when ``self.dropout > 0``.
    """
    if self.lin_first:
        x = self.lin(x)

    cached_x = self._cached_x
    if cached_x is None:
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, dtype=x.dtype)

        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
        if self.cached:
            # NOTE(review): x is stored without .detach(). With lin_first and
            # cached both set, it still references the first call's autograd
            # graph. Confirm that combination is not used in training.
            self._cached_x = x
    else:
        x = cached_x

    if self.bn:
        x = self.bn(x)
    if self.dropout > 0.:
        x = F.dropout(x, p=self.dropout, training=self.training)
    if not self.lin_first:
        x = self.lin(x)
    return x
def neighborhood_aggregation(self, x, adj_t):
    """Run ``self.K`` rounds of propagate-then-transform over ``adj_t``.

    Returns ``x`` unchanged when ``self.K <= 0``. Otherwise ``adj_t`` is
    optionally normalized with ``gcn_norm`` (no self-loops added), gets
    ``set_diag()`` when ``self.add_self_loops`` is set, and each round
    computes ``self.transform(self.propagate(adj_t, x=...))``.
    """
    if self.K <= 0:
        return x
    if self.normalize:
        adj_t = gcn_norm(adj_t, add_self_loops=False)
    if self.add_self_loops:
        adj_t = adj_t.set_diag()

    out = x
    for _ in range(self.K):
        out = self.transform(self.propagate(adj_t, x=out))
    return out
def forward(self, data):
    """Project node features with ``self.lin``, then apply ``self.hops`` hops.

    The graph in ``data.graph`` is normalized with ``gcn_norm``. For edge
    index tensors, a transposed ``SparseTensor`` is built from the result.
    Each hop is a sparse matmul with that adjacency.
    """
    graph = data.graph
    edge_index = graph['edge_index']
    x = self.lin(graph['node_feat'])
    n = graph['num_nodes']

    if isinstance(edge_index, torch.Tensor):
        edge_index, edge_weight = gcn_norm(edge_index, None, n, False,
                                           dtype=x.dtype)
        row, col = edge_index
        # Swap row/col to obtain the transposed adjacency.
        adj_t = SparseTensor(row=col, col=row, value=edge_weight,
                             sparse_sizes=(n, n))
    elif isinstance(edge_index, SparseTensor):
        adj_t = gcn_norm(edge_index, None, n, False, dtype=x.dtype)

    for _ in range(self.hops):
        x = matmul(adj_t, x)
    return x
def forward(self, data):
    """GCNII-style model forward pass.

    Converts ``data`` to a sparse adjacency, normalizes it with ``gcn_norm``,
    and applies the input projection ``self.lins[0]`` with ReLU. It then runs
    every layer in ``self.convs`` (each fed the initial representation
    ``x_0``) with dropout and ReLU, and finishes with ``self.lins[1]``.

    Returns:
        A tuple ``(z, log_probs)``. ``z`` is the hidden representation
        after the last convolution (post-ReLU). ``log_probs`` is the
        log-softmax of the final projection.
    """
    data = T.ToSparseTensor()(data)
    x, adj_t = data.x, data.adj_t
    adj_t = gcn_norm(adj_t)
    # For torch.sparse input, gcn_norm returns an ``(adj, None)`` tuple. For a
    # torch_sparse.SparseTensor it returns the adjacency itself, where ``[0]``
    # would select only its first row. Unwrap just the tuple case.
    if isinstance(adj_t, tuple):
        adj_t = adj_t[0]

    x = F.dropout(x, self.dropout, training=self.training)
    x = x_0 = self.lins[0](x).relu()

    for conv in self.convs:
        x = F.dropout(x, self.dropout, training=self.training)
        x = conv(x=x, x_0=x_0, edge_index=adj_t)
        x = x.relu()

    z = x
    x = F.dropout(x, self.dropout, training=self.training)
    x = self.lins[1](x)
    return z, x.log_softmax(dim=-1)
def forward(self, x, edge_index,
            edge_weight: Optional[torch.Tensor] = None):
    """APPNP-style propagation over ``self.K`` steps.

    Each step computes ``(1 - alpha) * propagate(x) + alpha * x_input``.
    Here ``x_input`` is the features passed in, and propagation uses the
    ``norm`` coefficients from ``gcn_norm``.
    """
    # NOTE(review): the positional order (num_nodes, edge_weight) differs
    # from PyG's gcn_norm(edge_index, edge_weight, num_nodes, ...). Confirm
    # which gcn_norm is imported here.
    edge_index, norm = gcn_norm(edge_index, x.size(self.node_dim),
                                edge_weight, dtype=x.dtype)
    x_input = x
    for _ in range(self.K):
        x = self.propagate(edge_index, x=x, norm=norm)
        x = x * (1 - self.alpha) + self.alpha * x_input
    return x
def forward(self, W: torch.FloatTensor, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """GCN convolution using an externally supplied weight matrix.

    Args:
        W: Weight matrix. ``x`` is multiplied by it before propagation.
        x: Node features.
        edge_index: Edge indices (``Tensor``).
        edge_weight: Optional edge weights.

    Returns:
        ``propagate(x @ W)`` over the graph. When ``self.normalize`` is set,
        the graph is normalized with ``gcn_norm`` (or taken from
        ``self._cached_edge_index`` when that cache is populated).
    """
    if self.normalize:
        cache = self._cached_edge_index
        if cache is None:
            # dtype=x.dtype keeps default edge weights in x's precision,
            # consistent with the other gcn_norm call sites.
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                self.improved, self.add_self_loops, dtype=x.dtype)
        else:
            # Previously a populated cache was read but ignored, so
            # normalization was silently skipped.
            edge_index, edge_weight = cache[0], cache[1]
    x = torch.matmul(x, W)
    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
    return out
def forward(self, x, adj_t, edge_types):
    """Calculate node embeddings with three typed-normalized conv layers.

    ``adj_t`` is normalized with ``gcn_norm`` (no self-loops added). After
    each conv layer, the rows along dim 0 are split into four contiguous
    segments: ``num_genes`` rows, then ``self.num_dis_nodes``, then
    ``self.num_comp_nodes``, then the remainder (presumably
    ``self.num_pathways`` rows). Each segment gets its own normalization
    module and the segments are re-concatenated. ReLU follows the first two
    conv layers and ``self.drop`` is applied at the end.
    """
    num_genes = (x.shape[-2] - self.num_dis_nodes - self.num_comp_nodes -
                 self.num_pathways)
    dis_end = num_genes + self.num_dis_nodes
    comp_end = dis_end + self.num_comp_nodes

    def per_type_norm(h, gene_norm, dis_norm, comp_norm, path_norm):
        # Normalize each node-type segment separately, then restack on dim 0.
        return torch.cat(
            (
                gene_norm(h[:num_genes]),
                dis_norm(h[num_genes:dis_end]),
                comp_norm(h[dis_end:comp_end]),
                path_norm(h[comp_end:]),
            ),
            0,
        )

    adj_t = gcn_norm(adj_t, num_nodes=x.size(-2), add_self_loops=False)

    h = F.relu(self.conv1(self.lin1(x), adj_t, edge_types))
    h = per_type_norm(h, self.normg1, self.normd1, self.normc1, self.normp1)

    h = F.relu(self.conv2(h, adj_t, edge_types))
    h = per_type_norm(h, self.normg2, self.normd2, self.normc2, self.normp2)

    h = self.conv3(h, adj_t, edge_types)
    h = per_type_norm(h, self.normg3, self.normd3, self.normc3, self.normp3)

    return self.drop(h)
def __norm__(self, x, edge_index,
             edge_weight: Optional[torch.Tensor] = None):
    """Normalize the graph with ``gcn_norm``, reusing a cached result if any.

    Returns:
        ``(edge_index, edge_weight)`` after normalization. On a cache hit
        this is the tail slice of the cached tuple.

    Raises:
        RuntimeError: If a cached result exists but ``edge_index`` has a
            different number of edges than when it was cached.
    """
    cache = self._cache
    if cache is None:
        edges_before = edge_index.size(1)
        # NOTE(review): the positional order (num_nodes, edge_weight) differs
        # from PyG's gcn_norm(edge_index, edge_weight, num_nodes, ...).
        # Confirm which gcn_norm is imported here.
        edge_index, edge_weight = gcn_norm(edge_index, x.size(self.node_dim),
                                           edge_weight, dtype=x.dtype)
        if self.cached:
            self._cache = (edges_before, edge_index, edge_weight)
        return edge_index, edge_weight

    if edge_index.size(1) != cache[0]:
        raise RuntimeError(
            'Cached {} number of edges, but found {}. Please disable '
            'the caching behavior of this layer by removing the '
            '`cached=True` argument in its constructor.'.format(
                cache[0], edge_index.size(1)))
    return cache[1:]
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """Efficient-graph-convolution (EGConv-style) forward pass.

    ``x`` is projected to basis features (``self.bases_lin``) and to
    per-head combination weights (``self.comb_lin``). The bases are
    aggregated with every aggregator in ``self.aggregators`` and combined
    per head through a batched matrix product. ``self.bias`` is added if
    present.

    When ``"symnorm"`` is an aggregator, the graph is normalized with
    ``gcn_norm``. Otherwise self-loops are added when
    ``self.add_self_loops`` is set. Either result is cached when
    ``self.cached`` is true.
    """
    num_nodes = x.size(self.node_dim)
    symnorm_weight: OptTensor = None
    if "symnorm" in self.aggregators:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                # dtype=x.dtype keeps the symnorm weights in x's precision,
                # so the final matmul does not mix dtypes (e.g. under fp16).
                edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=num_nodes,
                    improved=False, add_self_loops=self.add_self_loops,
                    dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, symnorm_weight)
            else:
                edge_index, symnorm_weight = cache

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=num_nodes,
                    improved=False, add_self_loops=self.add_self_loops,
                    dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    elif self.add_self_loops:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if self.cached and cache is not None:
                edge_index = cache[0]
            else:
                # Pass num_nodes so trailing isolated nodes (indices above
                # the largest index in edge_index) also get a self-loop, as
                # in the symnorm branch.
                edge_index, _ = add_remaining_self_loops(
                    edge_index, num_nodes=num_nodes)
                if self.cached:
                    self._cached_edge_index = (edge_index, None)

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if self.cached and cache is not None:
                edge_index = cache
            else:
                edge_index = fill_diag(edge_index, 1.0)
                if self.cached:
                    self._cached_adj_t = edge_index

    # [num_nodes, (out_channels // num_heads) * num_bases]
    bases = self.bases_lin(x)
    # [num_nodes, num_heads * num_bases * num_aggrs]
    weightings = self.comb_lin(x)

    # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
    # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
    aggregated = self.propagate(edge_index, x=bases,
                                symnorm_weight=symnorm_weight, size=None)

    weightings = weightings.view(-1, self.num_heads,
                                 self.num_bases * len(self.aggregators))
    aggregated = aggregated.view(
        -1,
        len(self.aggregators) * self.num_bases,
        self.out_channels // self.num_heads,
    )

    # [num_nodes, num_heads, out_channels // num_heads]
    out = torch.matmul(weightings, aggregated)
    out = out.view(-1, self.out_channels)

    if self.bias is not None:
        out += self.bias

    return out
adj_t = torch.load(path) else: path_sym = dataset.root + '/mag240m/paper_to_paper_symmetric.pt' if osp.exists(path_sym): adj_t = torch.load(path_sym) else: edge_index = dataset.edge_index('paper', 'cites', 'paper') edge_index = torch.from_numpy(edge_index) adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(dataset.num_papers, dataset.num_papers), is_sorted=True) adj_t = adj_t.to_symmetric() torch.save(adj_t, path_sym) adj_t = gcn_norm(adj_t, add_self_loops=True) torch.save(adj_t, path) print(f'Done! [{time.perf_counter() - t:.2f}s]') train_idx = dataset.get_idx_split('train') valid_idx = dataset.get_idx_split('valid') test_idx = dataset.get_idx_split('test') num_features = dataset.num_paper_features pbar = tqdm(total=args.num_layers * (num_features // 128)) pbar.set_description('Pre-processing node features') for j in range(0, num_features, 128): # Run spmm in chunks... x = dataset.paper_feat[:, j:min(j + 128, num_features)] x = torch.from_numpy(x.astype(np.float32))
def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:
    """GCNII convolution with initial residual and identity mapping.

    Computes ``(1 - alpha) * propagate(x)`` and ``alpha * x_0``. ``x_0`` is
    first truncated to the propagated row count. Each term is blended with
    its product by a weight matrix via ``torch.addmm``:
    ``(1 - beta) * h + beta * (h @ W)``. With ``self.weight2`` set, the two
    terms use ``weight1`` and ``weight2`` respectively. Otherwise their sum
    uses ``weight1``. The propagation output is modified in place.

    When ``self.normalize`` is set, the graph is normalized with
    ``gcn_norm``, and the result is cached if ``self.cached`` is true.
    """
    if self.normalize:
        if isinstance(edge_index, Tensor):
            cached = self._cached_edge_index
            if cached is not None:
                edge_index, edge_weight = cached[0], cached[1]
            else:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
        elif isinstance(edge_index, SparseTensor):
            cached = self._cached_adj_t
            if cached is not None:
                edge_index = cached
            else:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
    x.mul_(1 - self.alpha)
    x_0 = self.alpha * x_0[:x.size(0)]

    keep = 1. - self.beta
    if self.weight2 is not None:
        out = torch.addmm(x, x, self.weight1, beta=keep, alpha=self.beta)
        out += torch.addmm(x_0, x_0, self.weight2, beta=keep,
                           alpha=self.beta)
        return out

    out = x.add_(x_0)
    return torch.addmm(out, out, self.weight1, beta=keep, alpha=self.beta)
# Load (or build and cache) the paper-to-paper adjacency, then normalize it
# with gcn_norm for label propagation.
t = time.perf_counter()
print('Reading adjacency matrix...', end=' ', flush=True)
path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
if osp.exists(path):
    # Reuse the symmetrized (not yet normalized) adjacency saved by an
    # earlier run.
    adj_t = torch.load(path)
else:
    edge_index = dataset.edge_index('paper', 'cites', 'paper')
    edge_index = torch.from_numpy(edge_index)
    adj_t = SparseTensor(
        row=edge_index[0], col=edge_index[1],
        sparse_sizes=(dataset.num_papers, dataset.num_papers),
        is_sorted=True)
    # Symmetrize the 'cites' graph before caching it.
    adj_t = adj_t.to_symmetric()
    torch.save(adj_t, path)
# Normalization runs after load/save, so the cached file holds the raw
# symmetric adjacency.
adj_t = gcn_norm(adj_t, add_self_loops=False)
if args.low_memory:
    # Store the normalized adjacency in float16 to reduce memory use.
    adj_t = adj_t.to(torch.half)
print(f'Done! [{time.perf_counter() - t:.2f}s]')

train_idx = dataset.get_idx_split('train')
valid_idx = dataset.get_idx_split('valid')
test_idx = dataset.get_idx_split('test')

# Labels of the train/valid papers as int64 tensors.
y_train = torch.from_numpy(dataset.paper_label[train_idx]).to(torch.long)
y_valid = torch.from_numpy(dataset.paper_label[valid_idx]).to(torch.long)

model = LabelPropagation(args.num_layers, args.alpha)

# N: number of papers (nodes); C: number of classes.
N, C = dataset.num_papers, dataset.num_classes