import torch
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops


def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    fill_value = 1 if not improved else 2
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
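# --- Usage sketch (added; not part of the original source): a minimal check of
# the norm() above on a 3-node path graph, assuming the imports above. It
# prints the symmetrically normalized weights D^-1/2 (A + I) D^-1/2 in COO form.
if __name__ == '__main__':
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])  # undirected path 0 - 1 - 2
    ei, w = norm(edge_index, num_nodes=3)
    print(ei)  # the 4 original edges plus 3 appended self-loops
    print(w)   # e.g. edge (0, 1) gets 1/sqrt(deg(0) * deg(1)) = 1/sqrt(2 * 3)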
import torch
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, mul
from torch_sparse import sum as sparsesum  # the original shadowed the builtin `sum`
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        deg = sparsesum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t
    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
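# --- Usage sketch (added; not in the original): on a symmetric toy graph the
# two input paths of gcn_norm() should agree up to edge ordering. Assumes
# torch_sparse is installed; `gcn_norm` is the function defined directly above.
if __name__ == '__main__':
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    ei, w = gcn_norm(edge_index, num_nodes=3)
    adj_t = gcn_norm(SparseTensor.from_edge_index(edge_index,
                                                  sparse_sizes=(3, 3)))
    row, col, val = adj_t.coo()
    # Both realize D^-1/2 (A + I) D^-1/2; compare (ei, w) with (row, col, val).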
import torch
import torch_geometric.utils as pyg_utils


def compute_identity(edge_index, n, k):
    """Diagonal entries of the first k powers of the symmetrically normalized
    adjacency (with self-loops), stacked as per-node features."""
    edge_weight = torch.ones((edge_index.size(1), ), dtype=torch.float,
                             device=edge_index.device)
    edge_index, edge_weight = pyg_utils.add_remaining_self_loops(
        edge_index, edge_weight, 1, n)
    adj_sparse = torch.sparse.FloatTensor(edge_index, edge_weight,
                                          torch.Size([n, n]))
    adj = adj_sparse.to_dense()

    deg = torch.diag(torch.sum(adj, -1))
    deg_inv_sqrt = deg.pow(-0.5)
    # 0 ** -0.5 = inf on the off-diagonal zeros of the dense degree matrix;
    # zero them out, otherwise the matrix products below produce inf/nan.
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj = deg_inv_sqrt @ adj @ deg_inv_sqrt

    diag_all = [torch.diag(adj)]
    adj_power = adj
    for i in range(1, k):
        adj_power = adj_power @ adj
        diag_all.append(torch.diag(adj_power))
    diag_all = torch.stack(diag_all, dim=1)
    return diag_all
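# --- Usage sketch (added; not in the original): per-node identity features on
# a 3-node path graph, one column per power of the normalized adjacency.
# Assumes the inf-fix in compute_identity() above.
if __name__ == '__main__':
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    feat = compute_identity(edge_index, n=3, k=3)
    print(feat.shape)  # torch.Size([3, 3]): n nodes x k diagonal features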
def forward(self, data):
    x, edge_index, y, batch = data.x, data.edge_index, data.y, data.batch
    edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=x.shape[0])
    edge_list = []
    perm_list = []
    shape_list = []
    edge_weight = None

    # Encoder: down-convolutions with pooling.
    f, e, b = x, edge_index, batch
    for i in range(self.depth):
        # (the original guarded this append with `i < self.depth`, which is
        # always true inside this loop)
        edge_list.append(e)
        f, attn = self.down_list[i](f, e, self.direction)
        shape_list.append(f.shape)
        f = F.leaky_relu(f)
        f, e, _, b, perm, _ = self.pool_list[i](f, e, edge_weight, b, attn)
        if i < self.depth - 1:
            e, _ = self.augment_adj(e, None, f.shape[0])
        perm_list.append(perm)

    latent_x, latent_edge = f, e

    # Decoder: unpool by scattering into the saved shapes, then up-convolve.
    z = f
    for i in range(self.depth):
        index = self.depth - i - 1
        shape = shape_list[index]
        up = torch.zeros(shape).to(self.device)
        p = perm_list[index]
        up[p] = z
        z = self.up_list[i](up, edge_list[index])
        if i < self.depth - 1:
            z = torch.relu(z)

    edge_list.clear()
    perm_list.clear()
    shape_list.clear()
    return z, latent_x, latent_edge, b
def forward(self, x, edge_index):
    # pass num_nodes explicitly so trailing isolated nodes are not dropped
    edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=x.size(0))
    row, col = edge_index
    deg = degree(row, x.size(0))
    deg_inv_sqrt = deg.pow(-0.5)
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    if self.norm == 'neighbornorm':
        x_j = self.normlayer(x, edge_index)
    else:
        x_j = x[col]
    x_j = norm.view(-1, 1) * x_j
    out = scatter_add(src=x_j, index=row, dim=0, dim_size=x.size(0))

    if self.norm == 'neighbornorm':
        out = F.relu(self.linear(out))
    else:
        out = self.normlayer(F.relu(self.linear(out)))
    return out
def forward(self, x, edge_index, edge_attr, edge_weight=None):
    x = torch.matmul(x, self.weight)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype,
                                 device=edge_index.device)
    fill_value = 1 if not self.improved else 2
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, x.size(0))

    # Pad edge_attr with zero vectors for the added self-loops.
    # NOTE: this assumes the input graph has no pre-existing self-loops, so
    # that exactly x.size(0) loops are appended.
    self_loop_edges = torch.zeros(x.size(0),
                                  edge_attr.size(1)).to(edge_index.device)
    edge_attr = torch.cat([edge_attr, self_loop_edges], dim=0)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=x.size(0))
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    return self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)
def forward(self, x, edge_index, edge_weight=None, size=None, res_n_id=None):
    """
    Args:
        res_n_id (Tensor, optional): Residual node indices coming from
            :obj:`DataFlow` generated by :obj:`NeighborSampler` are used to
            select central node features in :obj:`x`. Required if operating
            in a bipartite graph and :obj:`concat` is :obj:`True`.
            (default: :obj:`None`)
    """
    if not self.concat and torch.is_tensor(x):
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, 1, x.size(self.node_dim))

    return self.propagate(edge_index, size=size, x=x,
                          edge_weight=edge_weight, res_n_id=res_n_id)
def forward(self, x, edge_index, edge_weight=None, pseudo=None, size=None):
    """"""
    if edge_weight is not None:  # squeeze() on the default None would raise
        edge_weight = edge_weight.squeeze()
    if size is None and torch.is_tensor(x):
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, 1, x.size(0))

    weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
    if torch.is_tensor(x):
        x = torch.matmul(x.unsqueeze(1), weight).squeeze(1)
    else:
        x = (None if x[0] is None else
             torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
             None if x[1] is None else
             torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

    # Alternative weight layout kept from the original (transposed view):
    # weight = self.nn(pseudo).view(-1, self.out_channels, self.in_channels)
    # if torch.is_tensor(x):
    #     x = torch.matmul(x.unsqueeze(1), weight.permute(0, 2, 1)).squeeze(1)
    # else:
    #     x = (None if x[0] is None else
    #          torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
    #          None if x[1] is None else
    #          torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

    return self.propagate(edge_index, size=size, x=x, edge_weight=edge_weight)
def graph_connectivity(device, perm, edge_index, edge_weight, score, ratio,
                       batch, N):
    r"""graph_connectivity: internally calls the StAS function to maintain
    graph connectivity after pooling."""
    kN = perm.size(0)
    perm2 = perm.view(-1, 1)

    # mask contains a bool mask of edges that originate from perm (selected) nodes
    mask = (edge_index[0] == perm2).sum(0, dtype=torch.bool)

    # create S
    S0 = edge_index[1][mask].view(1, -1)
    S1 = edge_index[0][mask].view(1, -1)
    index_S = torch.cat([S0, S1], dim=0)
    value_S = score[mask].detach().squeeze()

    # relabel for pooling, i.e. make S [N x kN]
    # (kept on perm's device; the original allocated on CPU, which fails when
    # perm lives on the GPU)
    n_idx = torch.zeros(N, dtype=torch.long, device=perm.device)
    n_idx[perm] = torch.arange(perm.size(0), device=perm.device)
    index_S[1] = n_idx[index_S[1]]

    # create A
    index_A = edge_index.clone()
    if edge_weight is None:
        value_A = value_S.new_ones(edge_index[0].size(0))
    else:
        value_A = edge_weight.clone()

    fill_value = 1
    index_E, value_E = StAS(index_A, value_A, index_S, value_S, device, N, kN)
    index_E, value_E = remove_self_loops(edge_index=index_E,
                                         edge_attr=value_E)
    index_E, value_E = add_remaining_self_loops(edge_index=index_E,
                                                edge_weight=value_E,
                                                fill_value=fill_value,
                                                num_nodes=kN)
    return index_E, value_E
def load_cora():
    edges = pd.read_csv(CORA + 'cora_cites.csv')
    data = pd.read_csv(CORA + 'cora_content.csv')

    id_to_node = dict([(row['paper_id'], idx)
                       for idx, row in data.iterrows()])
    class_to_int = dict([(c, i) for i, c in enumerate(set(data['label']))])

    # COO matrix of edges, converted to node ids to match the feature tensor
    citing = [id_to_node[e] for e in edges['citing_paper_id']]
    cited = [id_to_node[e] for e in edges['cited_paper_id']]

    # Undirected since there are so many orphans otherwise
    ei = torch.tensor([
        citing,  # + cited,
        cited,   # + citing
    ])
    ei = add_remaining_self_loops(ei)[0]

    # Don't need paper ids or class in node attribute vectors
    X = torch.tensor(data.iloc[:, 1:-1].values, dtype=torch.float)

    # One-hot class labels
    y = torch.zeros(X.size()[0], len(class_to_int))
    i = 0
    for c in data['label']:
        y[i][class_to_int[c]] = 1
        i += 1

    # Inverse-frequency class weights
    weights = y.sum(dim=0)
    weights = weights.max() / weights

    return Data(x=X, edge_index=ei, y=y, weights=weights,
                num_nodes=X.size()[0])
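# --- Usage sketch (added; not in the original): assumes CORA points at a
# directory containing cora_cites.csv and cora_content.csv.
if __name__ == '__main__':
    cora = load_cora()
    print(cora.num_nodes, cora.edge_index.shape, cora.y.shape)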
def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
    # edge_index has shape [2, num_edges]; its two rows are the row and
    # column indices, in that order
    # edge_weight has one entry per edge
    fill_value = 1 if not improved else 2
    # add self-loops with weight fill_value
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    # degree = per-node sum of edge weights
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    # result is still in COO form
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
import torch
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops


def diag_enhance_norm(edge_index, num_nodes, edge_weight=None, improved=False,
                      dtype=None, diag_lambda=1.0):
    """Symmetric GCN normalization with diagonally enhanced self-loop weights:
    A_norm + diag_lambda * diag(A_norm)."""
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    fill_value = 1 if not improved else 2
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm_edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    # keep only the self-loop entries, then boost them by diag_lambda
    diag_edge_weight = norm_edge_weight.clone()
    diag_edge_weight[edge_index[0] != edge_index[1]] = 0

    return (edge_index, norm_edge_weight + diag_lambda * diag_edge_weight)
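# --- Usage sketch (added; not in the original): with diag_lambda=1.0 the
# self-loop weights come out scaled by (1 + diag_lambda) relative to the plain
# symmetric normalization, while off-diagonal weights are unchanged.
if __name__ == '__main__':
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    ei, w = diag_enhance_norm(edge_index, num_nodes=3, diag_lambda=1.0)
    loop_mask = ei[0] == ei[1]
    print(w[loop_mask])   # boosted self-loop weights
    print(w[~loop_mask])  # untouched off-diagonal weights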
def My_norms(self, x_norm, edge_index, num_nodes, edge_weight=None,
             improved=False, dtype=None):
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    fill_value = 1 if not improved else 2
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    edge_index_j = edge_index[0]
    edge_index_i = edge_index[1]
    x_norm_j = x_norm[edge_index_j]
    x_norm_i = x_norm[edge_index_i]

    # dot-product attention on the (already normalized) node embeddings,
    # softmax-normalized over each target node's incoming edges
    alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1)
    # use the keyword to avoid clashing with the `ptr` argument of newer
    # torch_geometric softmax signatures
    alpha = softmax(alpha, edge_index_i, num_nodes=num_nodes)
    return edge_index, alpha
def __dropout_adj__(self, sparse_adj: SparseTensor, dropout_adj_prob: float):
    # number of nodes
    N = sparse_adj.size(0)

    # SparseTensor to COO edge_index plus edge attributes
    row, col, edge_attr = sparse_adj.coo()
    edge_index = torch.stack([row, col], dim=0)

    # dropout on the adjacency matrix -> generalization
    edge_index, edge_attr = dropout_adj(edge_index, edge_attr=edge_attr,
                                        p=dropout_adj_prob,
                                        force_undirected=True,
                                        training=self.training)

    # because dropout removes self-loops (due to force_undirected=True),
    # make sure to add them back again
    edge_index, edge_attr = add_remaining_self_loops(edge_index,
                                                     edge_weight=edge_attr,
                                                     fill_value=0.00,
                                                     num_nodes=N)

    # COO edge_index back to a SparseTensor adjacency
    sparse_adj = SparseTensor.from_edge_index(edge_index,
                                              edge_attr=edge_attr,
                                              sparse_sizes=(N, N))
    return sparse_adj
import torch
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops


def norm(edge_index, num_nodes, edge_weight, dtype=None):
    """Edge weights of the symmetric normalized Laplacian
    I - D^-1/2 A D^-1/2, returned in COO form."""
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    # degrees of the original graph (before self-loops are added)
    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    # append zero-weight self-loops; they carry the identity part below.
    # NOTE: this assumes no pre-existing self-loops, so that the appended
    # loops are exactly the last num_nodes edges.
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, 0, num_nodes)
    row, col = edge_index

    expand_deg = torch.zeros((edge_weight.size(0), ), dtype=dtype,
                             device=edge_index.device)
    expand_deg[-num_nodes:] = torch.ones((num_nodes, ), dtype=dtype,
                                         device=edge_index.device)
    return edge_index, (expand_deg -
                        deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col])
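# --- Usage sketch (added; not in the original): the returned (edge_index,
# weight) pair of the norm() directly above assembles the symmetric normalized
# Laplacian I - D^-1/2 A D^-1/2 when densified.
if __name__ == '__main__':
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    ei, w = norm(edge_index, 3, None)
    L = torch.zeros(3, 3)
    L[ei[0], ei[1]] = w
    print(L)  # ones on the diagonal, -1/sqrt(deg_i * deg_j) off-diagonal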
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """"""
    symnorm_weight: OptTensor = None
    if "symnorm" in self.aggregators:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=x.size(self.node_dim),
                    improved=False, add_self_loops=self.add_self_loops)
                if self.cached:
                    self._cached_edge_index = (edge_index, symnorm_weight)
            else:
                edge_index, symnorm_weight = cache
        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, None, num_nodes=x.size(self.node_dim),
                    improved=False, add_self_loops=self.add_self_loops)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    elif self.add_self_loops:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if self.cached and cache is not None:
                edge_index = cache[0]
            else:
                edge_index, _ = add_remaining_self_loops(edge_index)
                if self.cached:
                    self._cached_edge_index = (edge_index, None)
        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if self.cached and cache is not None:
                edge_index = cache
            else:
                edge_index = fill_diag(edge_index, 1.0)
                if self.cached:
                    self._cached_adj_t = edge_index

    # [num_nodes, (out_channels // num_heads) * num_bases]
    bases = self.bases_lin(x)
    # [num_nodes, num_heads * num_bases * num_aggrs]
    weightings = self.comb_lin(x)

    # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
    # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
    aggregated = self.propagate(edge_index, x=bases,
                                symnorm_weight=symnorm_weight, size=None)

    weightings = weightings.view(-1, self.num_heads,
                                 self.num_bases * len(self.aggregators))
    aggregated = aggregated.view(
        -1,
        len(self.aggregators) * self.num_bases,
        self.out_channels // self.num_heads,
    )

    # [num_nodes, num_heads, out_channels // num_heads]
    out = torch.matmul(weightings, aggregated)
    out = out.view(-1, self.out_channels)

    if self.bias is not None:
        out += self.bias

    return out
def norm(self, edge_index, num_nodes, edge_weight=None, improved=False,
         dtype=None):
    adj_dict = {}

    def add_edge(a, b):
        if a in adj_dict:
            neighbors = adj_dict[a]
        else:
            neighbors = set()
            adj_dict[a] = neighbors
        if b not in neighbors:
            neighbors.add(b)

    cpu_device = torch.device("cpu")
    gpu_device = torch.device("cuda")
    for a, b in edge_index.t().detach().to(cpu_device).numpy():
        a = int(a)
        b = int(b)
        add_edge(a, b)
        add_edge(b, a)

    adj_dict = {a: list(neighbors) for a, neighbors in adj_dict.items()}

    def sample_neighbor(a):
        neighbors = adj_dict[a]
        random_index = np.random.randint(0, len(neighbors))
        return neighbors[random_index]

    # word_counter = Counter()
    walk_counters = {}

    def norm(counter):
        s = sum(counter.values())
        new_counter = Counter()
        for a, count in counter.items():
            new_counter[a] = counter[a] / s
        return new_counter

    # random walks: count, per start node, how often each node is visited
    for _ in tqdm(range(40)):
        for a in adj_dict:
            current_a = a
            current_path_len = np.random.randint(1, self.path_len + 1)
            for _ in range(current_path_len):
                b = sample_neighbor(current_a)
                if a in walk_counters:
                    walk_counter = walk_counters[a]
                else:
                    walk_counter = Counter()
                    walk_counters[a] = walk_counter
                walk_counter[b] += 1
                current_a = b

    normed_walk_counters = {
        a: norm(walk_counter)
        for a, walk_counter in walk_counters.items()
    }

    prob_sums = Counter()
    for a, normed_walk_counter in normed_walk_counters.items():
        for b, prob in normed_walk_counter.items():
            prob_sums[b] += prob

    # pointwise mutual information between walk start and visited nodes
    ppmis = {}
    for a, normed_walk_counter in normed_walk_counters.items():
        for b, prob in normed_walk_counter.items():
            ppmi = np.log(prob / prob_sums[b] * len(prob_sums) /
                          self.path_len)
            ppmis[(a, b)] = ppmi

    new_edge_index = []
    edge_weight = []
    for (a, b), ppmi in ppmis.items():
        new_edge_index.append([a, b])
        edge_weight.append(ppmi)

    edge_index = torch.tensor(new_edge_index).t().to(gpu_device)
    edge_weight = torch.tensor(edge_weight).to(gpu_device)

    fill_value = 1 if not improved else 2
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

    return edge_index, (deg_inv_sqrt[row] * edge_weight *
                        deg_inv_sqrt[col]).type(torch.float32)
def graph_max_pool(x, edge_index):
    edge_index, _ = add_remaining_self_loops(edge_index)
    source = edge_index[0]
    dest = edge_index[1]
    # max-aggregate each node's neighborhood features
    # (scatter_ is the pre-1.0 torch_geometric helper)
    return scatter_('max', x[dest], source, dim_size=len(x))
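# --- Equivalent sketch (added; not in the original): `scatter_` no longer
# exists in current torch_geometric, but the same max pooling can be written
# with torch_scatter directly. `graph_max_pool_v2` is a hypothetical name.
import torch
from torch_scatter import scatter
from torch_geometric.utils import add_remaining_self_loops


def graph_max_pool_v2(x, edge_index):
    edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=x.size(0))
    source, dest = edge_index
    # for each edge, take the destination node's features, then reduce by max
    # over every source node's outgoing edges
    return scatter(x[dest], source, dim=0, dim_size=x.size(0), reduce='max')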
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(
            edge_index,
            edge_attr=(torch.ones(edge_index.shape[1],
                                  device=device)).long(),
            p=edge_dropout, force_undirected=True)[0]
        edge_index = add_remaining_self_loops(
            edge_index, num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if (x.dim() > 1):
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()
    # x = F.leaky_relu(self.lin1(x))
    x = x * mask
    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(
            probs[batch_graph].unsqueeze(-1),
            probs[batch_graph].unsqueeze(-1))).sum() / 2

    # calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(
        probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0,
        dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                              self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    # calculate loss
    expected_loss = (penalty_coefficient
                     ) * expected_distance * 0.5 - 0.5 * expected_weight_G
    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [expected_clique_weight.mean(),
                                          "sequence"]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss
    return retdict
def main():
    global device
    global graphname

    print(socket.gethostname())
    seed = 0

    if not download:
        mp.set_start_method('spawn', force=True)
        outputs = None
        if "OMPI_COMM_WORLD_RANK" in os.environ.keys():
            os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

        # Initialize distributed environment with SLURM
        if "SLURM_PROCID" in os.environ.keys():
            os.environ["RANK"] = os.environ["SLURM_PROCID"]
        if "SLURM_NTASKS" in os.environ.keys():
            os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
        if "MASTER_ADDR" not in os.environ.keys():
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "1234"

        dist.init_process_group(backend='nccl')
        rank = dist.get_rank()
        size = dist.get_world_size()
        print("Processes: " + str(size))

        # device = torch.device('cpu')
        devid = rank_to_devid(rank, acc_per_rank)
        device = torch.device('cuda:{}'.format(devid))
        torch.cuda.set_device(device)
        curr_devid = torch.cuda.current_device()
        # print(f"curr_devid: {curr_devid}", flush=True)
        devcount = torch.cuda.device_count()

    if graphname == "Cora":
        path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                        graphname)
        dataset = Planetoid(path, graphname, T.NormalizeFeatures())
        data = dataset[0]
        data = data.to(device)
        data.x.requires_grad = True
        inputs = data.x.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
        edge_index = data.edge_index
        num_features = dataset.num_features
        num_classes = dataset.num_classes
    elif graphname == "Reddit":
        path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                        graphname)
        dataset = Reddit(path, T.NormalizeFeatures())
        data = dataset[0]
        data = data.to(device)
        data.x.requires_grad = True
        inputs = data.x.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
        edge_index = data.edge_index
        num_features = dataset.num_features
        num_classes = dataset.num_classes
    elif graphname == 'Amazon':
        # path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
        # edge_index = torch.load(path + "/processed/amazon_graph.pt")
        # edge_index = torch.load("/gpfs/alpine/bif115/scratch/alokt/Amazon/processed/amazon_graph_jsongz.pt")
        # edge_index = edge_index.t_()
        print(f"Loading coo...", flush=True)
        edge_index = torch.load("../data/Amazon/processed/data.pt")
        print(f"Done loading coo", flush=True)
        # n = 9430088
        n = 14249639
        num_features = 300
        num_classes = 24
        # mid_layer = 24
        inputs = torch.rand(n, num_features)
        data = Data()
        data.y = torch.rand(n).uniform_(0, num_classes - 1).long()
        data.train_mask = torch.ones(n).long()
        # edge_index = edge_index.to(device)
        print(f"edge_index.size: {edge_index.size()}", flush=True)
        print(f"edge_index: {edge_index}", flush=True)
        data = data.to(device)
        # inputs = inputs.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
    elif graphname == 'subgraph3':
        # path = "/gpfs/alpine/bif115/scratch/alokt/HipMCL/"
        # print(f"Loading coo...", flush=True)
        # edge_index = torch.load(path + "/processed/subgraph3_graph.pt")
        # print(f"Done loading coo", flush=True)
        print(f"Loading coo...", flush=True)
        edge_index = torch.load("../data/subgraph3/processed/data.pt")
        print(f"Done loading coo", flush=True)
        n = 8745542
        num_features = 128
        # mid_layer = 512
        # mid_layer = 64
        num_classes = 256
        inputs = torch.rand(n, num_features)
        data = Data()
        data.y = torch.rand(n).uniform_(0, num_classes - 1).long()
        data.train_mask = torch.ones(n).long()
        print(f"edge_index.size: {edge_index.size()}", flush=True)
        data = data.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)

    if download:
        exit()

    if normalization:
        adj_matrix, _ = add_remaining_self_loops(edge_index,
                                                 num_nodes=inputs.size(0))
    else:
        adj_matrix = edge_index

    init_process(rank, size, inputs, adj_matrix, data, num_features,
                 num_classes, device, outputs, run)

    if outputs is not None:
        return outputs[0]
def forward(self, x, edge_index, edge_attr, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # replace with MI
    x_information_score = self.calc_information_score(x, edge_index,
                                                      edge_attr)
    score = torch.sum(torch.abs(x_information_score), dim=1)

    # Graph Pooling
    original_x = x
    perm = topk(score, self.ratio, batch)
    x = x[perm]
    batch = batch[perm]
    induced_edge_index, induced_edge_attr = filter_adj(
        edge_index, edge_attr, perm, num_nodes=score.size(0))

    # Discard structure learning layer, directly return
    if self.sl is False:
        return x, induced_edge_index, induced_edge_attr, batch

    # Structure Learning
    if self.sample:
        # A fast mode for large graphs. In large graphs, learning the possible
        # edge weights between each pair of nodes is time consuming. To
        # accelerate this process, we sample each node's k-hop neighbors and
        # then learn the edge weights between them.
        k_hop = 3
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1), ), dtype=torch.float,
                                   device=edge_index.device)

        hop_data = Data(x=original_x, edge_index=edge_index,
                        edge_attr=edge_attr)
        for _ in range(k_hop - 1):
            hop_data = self.neighbor_augment(hop_data)
        hop_edge_index = hop_data.edge_index
        hop_edge_attr = hop_data.edge_attr
        new_edge_index, new_edge_attr = filter_adj(hop_edge_index,
                                                   hop_edge_attr, perm,
                                                   num_nodes=score.size(0))

        new_edge_index, new_edge_attr = add_remaining_self_loops(
            new_edge_index, new_edge_attr, 0, x.size(0))
        row, col = new_edge_index
        weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(
            weights, self.negative_slop) + new_edge_attr * self.lamb
        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float,
                          device=x.device)
        adj[row, col] = weights
        new_edge_index, weights = dense_to_sparse(adj)
        row, col = new_edge_index
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # filter out zero weight edges
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # release gpu memory
        del adj
        torch.cuda.empty_cache()
    else:
        # Learning the possible edge weights between each pair of nodes in
        # the pooled subgraph; relatively slower.
        if edge_attr is None:
            induced_edge_attr = torch.ones(
                (induced_edge_index.size(1), ), dtype=x.dtype,
                device=induced_edge_index.device)
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        shift_cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
        cum_num_nodes = num_nodes.cumsum(dim=0)
        adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float,
                          device=x.device)
        # Construct batched fully connected graph in block-diagonal matrix format
        for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
            adj[idx_i:idx_j, idx_i:idx_j] = 1.0
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index

        weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
        weights = F.leaky_relu(weights, self.negative_slop)
        adj[row, col] = weights
        induced_row, induced_col = induced_edge_index

        adj[induced_row, induced_col] += induced_edge_attr * self.lamb
        weights = adj[row, col]
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            new_edge_attr = softmax(weights, row, x.size(0))
        # filter out zero weight edges
        adj[row, col] = new_edge_attr
        new_edge_index, new_edge_attr = dense_to_sparse(adj)
        # release gpu memory
        del adj
        torch.cuda.empty_cache()

    return x, new_edge_index, new_edge_attr, batch
def forward(self, x, pos_edge_index, neg_edge_index,
            return_attention_weights=True):
    """"""
    # hyperbolic linear layer
    pos_edge_index = add_remaining_self_loops(pos_edge_index,
                                              num_nodes=x.size(0))[0]
    x = self.manifolds.proj(self.manifolds.expmap0(
        self.manifolds.proj_tan0(x, self.c), c=self.c), c=self.c)
    if self.manifolds.name != 'PoincareBall':
        drop_weight = F.dropout(self.weight, self.dropout,
                                training=self.training)
        mv = self.manifolds.mobius_matvec(drop_weight, x, self.c)
        res = self.manifolds.proj(mv, self.c)
    else:
        res = x
    if torch.isnan(res).any():
        print("check here")
    # assert not torch.isnan(res).any()
    if self.use_bias:
        bias = self.manifolds.proj_tan0(self.bias.view(1, -1), self.c)
        hyp_bias = self.manifolds.expmap0(bias, self.c)
        hyp_bias = self.manifolds.proj(hyp_bias, self.c)
        res = self.manifolds.mobius_add(res, hyp_bias, c=self.c)
        res = self.manifolds.proj(res, self.c)
    torch.cuda.empty_cache()
    x = (self.manifolds.logmap0(res, c=self.c)).cuda() + 1e-15

    if self.first_aggr:
        if self.manifolds.name == 'Hyperboloid':
            assert x.size(1) == self.in_channels - 1
        else:
            assert x.size(1) == self.in_channels

        if return_attention_weights:
            x_trans_pos = (self.lin_pos_agg(x), self.lin_pos_agg(x))
            x_trans_neg = (self.lin_neg_agg(x), self.lin_neg_agg(x))
        else:
            x_trans_pos = x
            x_trans_neg = x

        x_pos = torch.cat([
            self.propagate(
                pos_edge_index, x=x_trans_pos, size=None,
                return_attention_weights=return_attention_weights), x
        ], dim=1)
        x_neg = torch.cat([
            self.propagate(
                neg_edge_index, x=x_trans_neg, size=None,
                return_attention_weights=return_attention_weights), x
        ], dim=1)
    else:
        assert x.size(1) == 2 * self.in_channels
        x_1, x_2 = x[:, :self.in_channels], x[:, self.in_channels:]

        x_pos = torch.cat([
            self.propagate(
                pos_edge_index,
                x=(self.lin_pos_agg(x_1), self.lin_pos_agg(x_1)), size=None,
                return_attention_weights=return_attention_weights),
            self.propagate(
                neg_edge_index,
                x=(self.lin_neg_agg(x_2), self.lin_neg_agg(x_2)), size=None,
                return_attention_weights=return_attention_weights),
            x_1,
        ], dim=1)
        x_neg = torch.cat([
            self.propagate(
                pos_edge_index,
                x=(self.lin_pos_agg(x_2), self.lin_pos_agg(x_2)), size=None,
                return_attention_weights=return_attention_weights),
            self.propagate(
                neg_edge_index,
                x=(self.lin_neg_agg(x_1), self.lin_neg_agg(x_1)), size=None,
                return_attention_weights=return_attention_weights),
            x_2,
        ], dim=1)

    # to ensure numerical stability
    x_pos = x_pos + 1e-15
    x_neg = x_neg + 1e-15
    assert not torch.isnan(x_pos).any()
    assert not torch.isnan(x_neg).any()

    x_pos = self.manifolds.proj(self.manifolds.expmap0(self.lin_pos(x_pos),
                                                       c=self.c), c=self.c)
    x_neg = self.manifolds.proj(self.manifolds.expmap0(self.lin_neg(x_neg),
                                                       c=self.c), c=self.c)
    x_out = torch.cat([x_pos, x_neg], dim=1)

    xt = self.act(self.manifolds.logmap0(x_out, c=self.c),
                  self.negative_slope)
    xt = self.manifolds.proj_tan0(xt, c=self.c)
    xt = self.manifolds.proj(self.manifolds.expmap0(xt, c=self.c), c=self.c)
    if torch.isnan(xt).any():
        print("check here")
    assert not torch.isnan(xt).any()
    return xt
# (fragment: the head of this forward() was lost upstream; what remains is the
# tail of a NeighborNorm-style normalization that standardizes each node's
# neighborhood features and applies a learned affine transform)
    mean = torch.mean(mean, dim=-1, keepdim=True)
    var = scatter_mean((x[col] - mean[row])**2, row, dim=0,
                       dim_size=x.size(0))
    var = torch.mean(var, dim=-1, keepdim=True)
    # std = scatter_std(x[col], row, dim=0, dim_size=x.size(0))
    out = (x[col] - mean[row]) / (var[row] + self.eps).sqrt()
    # out = (x[col] - mean[row]) / (std[row]**2 + self.eps).sqrt()
    out = self.gamma * out + self.beta
    return out


if __name__ == '__main__':
    from torch_geometric.data import Data
    from torch_geometric.utils import add_remaining_self_loops

    edge_index = torch.tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [2, 4], [3, 2], [4, 2]],
        dtype=torch.long)
    x = torch.tensor([[-1, 2, 3], [3, 2, 1], [1, 6, 9], [2, 3, 6], [3, 2, 8]],
                     dtype=torch.float)
    data = Data(x=x, edge_index=edge_index.t().contiguous())

    edge_index, _ = add_remaining_self_loops(data.edge_index)
    row, col = edge_index
    x = data.x
    print(x[col])

    neighbornorm = NeighborNorm(3)
    y = neighbornorm.forward(x, edge_index)
    print(y)
def forward(self, x, edge_index, edge_weight=None, batch=None):
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # NxF
    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # Add self-loops
    fill_value = 1
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index=edge_index, edge_weight=edge_weight,
        fill_value=fill_value, num_nodes=num_nodes.sum())

    N = x.size(0)  # total number of nodes in the batch

    # ExF
    x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                    edge_weight=edge_weight)
    x_pool_j = x_pool[edge_index[1]]
    x_j = x[edge_index[1]]

    # ---Master query formation---
    # NxF
    X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
    # NxF
    M_q = self.lin_q(X_q)
    # ExF
    M_q = M_q[edge_index[0].tolist()]

    score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout_att, training=self.training)

    # ExF
    v_j = x_j * score.view(-1, 1)

    # ---Aggregation---
    # NxF
    out = scatter_add(v_j, edge_index[0], dim=0)

    # ---Cluster Selection---
    # Nx1
    fitness = torch.sigmoid(self.gnn_score(x=out,
                                           edge_index=edge_index)).view(-1)
    perm = topk(x=fitness, ratio=self.ratio, batch=batch)
    x = out[perm] * fitness[perm].view(-1, 1)

    # ---Maintaining Graph Connectivity---
    batch = batch[perm]
    edge_index, edge_weight = graph_connectivity(
        device=x.device, perm=perm, edge_index=edge_index,
        edge_weight=edge_weight, score=score, ratio=self.ratio,
        batch=batch, N=N)
    return x, edge_index, edge_weight, batch, perm
def get_diracs(data, N, n_diracs=1, sparse=False, flat=False, replace=True,
               receptive_field=7, effective_volume_range=0.1,
               max_iterations=20, complement=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not sparse:
        graphcount = data.num_nodes  # number of graphs in data/batch object
        totalnodecount = data.x.shape[1]  # number of total nodes for each graph
        actualnodecount = 0  # cumulative number of nodes
        diracmatrix = torch.zeros((graphcount, totalnodecount, N),
                                  device=device)  # matrix with dirac pulses

        for k in range(graphcount):
            graph_nodes = data.mask[k].sum()  # number of nodes in the graph
            actualnodecount += graph_nodes  # might not need this, we'll see
            probabilities = torch.ones(
                (graph_nodes.item(), 1),
                device=device) / graph_nodes  # uniform probs
            node_distribution = OneHotCategorical(
                probs=probabilities.squeeze())
            node_sample = node_distribution.sample(sample_shape=(N, ))
            node_sample = torch.cat(
                (node_sample,
                 torch.zeros((N, totalnodecount - node_sample.shape[1]),
                             device=device)),
                -1)  # concat zeros to fit dataset shape
            diracmatrix[k, :] = torch.transpose(
                node_sample, dim0=-1, dim1=-2)  # add everything to the final matrix
        return diracmatrix

    else:
        if not is_undirected(data.edge_index):
            data.edge_index = to_undirected(data.edge_index)

        original_batch_index = data.batch
        original_edge_index = add_remaining_self_loops(
            data.edge_index, num_nodes=data.batch.shape[0])[0]
        batch_index = original_batch_index
        graphcount = data.num_graphs
        batch_prime = torch.zeros(0, device=device).long()

        r, c = original_edge_index
        global_offset = 0
        all_nodecounts = scatter_add(
            torch.ones_like(batch_index, device=device), batch_index, 0)
        recfield_vols = torch.zeros(graphcount, device=device)
        total_vols = torch.zeros(graphcount, device=device)

        for j in range(n_diracs):
            diracmatrix = torch.zeros(0, device=device)
            locationmatrix = torch.zeros(0, device=device).long()

            for k in range(graphcount):
                # get edges of current graph, remember to subtract offset
                graph_nodes = all_nodecounts[k]
                if graph_nodes == 0:
                    print("all nodecounts: ", all_nodecounts)
                graph_edges = (batch_index[r] == k)
                graph_edge_index = original_edge_index[:, graph_edges] \
                    - global_offset
                gr, gc = graph_edge_index

                # get dirac
                randInt = np.random.choice(range(graph_nodes), N,
                                           replace=replace)
                node_sample = torch.zeros(N * graph_nodes, device=device)
                offs = torch.arange(N, device=device) * graph_nodes
                dirac_locations = (offs +
                                   torch.from_numpy(randInt).to(device))
                node_sample[dirac_locations] = 1

                # calculate receptive field volume and compare to total volume
                mask = get_mask(node_sample, graph_edge_index.detach(),
                                receptive_field).float()
                deg_graph = degree(gr, (graph_nodes.item()))
                total_volume = deg_graph.sum()
                recfield_volume = (mask * deg_graph).sum()
                volume_range = recfield_volume / total_volume
                total_vols[k] = total_volume
                recfield_vols[k] = recfield_volume

                # if receptive field volume is less than x% of total volume,
                # resample
                for iteration in range(max_iterations):
                    randInt = np.random.choice(range(graph_nodes), N,
                                               replace=replace)
                    node_sample = torch.zeros(N * graph_nodes, device=device)
                    offs = torch.arange(N, device=device) * graph_nodes
                    dirac_locations = (offs +
                                       torch.from_numpy(randInt).to(device))
                    node_sample[dirac_locations] = 1

                    mask = get_mask(node_sample, graph_edge_index,
                                    receptive_field).float()
                    recfield_volume = (mask * deg_graph).sum()
                    volume_range = recfield_volume / total_volume
                    if volume_range > effective_volume_range:
                        recfield_vols[k] = recfield_volume
                        total_vols[k] = total_volume
                        break

                dirac_locations2 = torch.from_numpy(randInt).to(device) \
                    + global_offset
                global_offset += graph_nodes

                diracmatrix = torch.cat((diracmatrix, node_sample), 0)
                locationmatrix = torch.cat(
                    (locationmatrix, dirac_locations2), 0)

        locationmatrix = diracmatrix.nonzero()

        if complement:
            return Batch(batch=batch_index, x=diracmatrix,
                         edge_index=original_edge_index, y=data.y,
                         locations=locationmatrix,
                         volume_range=volume_range,
                         recfield_vol=recfield_vols, total_vol=total_vols,
                         complement_edge_index=data.complement_edge_index)
        else:
            return Batch(batch=batch_index, x=diracmatrix,
                         edge_index=original_edge_index, y=data.y,
                         locations=locationmatrix,
                         volume_range=volume_range,
                         recfield_vol=recfield_vols, total_vol=total_vols)
def forward(self, x, edge_idx, n, d):
    edge_idx, _ = add_remaining_self_loops(edge_idx)
    # NOTE: torch_sparse.spmm expects (index, value, m, n, matrix); the call
    # below passes x in the index position, which only works if x is an
    # edge-index-like tensor at this point.
    x = spmm(x, torch.ones_like(x[0]), n, d, self.weight)
    return self.propagate(edge_idx, x=x)
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(
            edge_index,
            edge_attr=(torch.ones(edge_index.shape[1],
                                  device=device)).long(),
            p=edge_dropout, force_undirected=True)[0]
        edge_index = add_remaining_self_loops(
            edge_index, num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if (x.dim() > 1):
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()
    # x = F.leaky_relu(self.lin1(x))
    x = x * mask
    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    x2 = x.detach()
    deg = degree(row).unsqueeze(-1)
    totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device),
                           batch, 0) + 1e-6
    totalcard = scatter_add(torch.ones_like(x, device=device), batch,
                            0) + 1e-6
    x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()
    vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
    card_1 = scatter_add(probs, batch, 0)
    set_size = scatter_add(x2, batch, 0)
    vol_hard = scatter_add(deg * x2, batch, 0,
                           dim_size=batch.max().item() + 1) + 1e-6
    total_vol_ratio = vol_hard / totalvol

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(
            probs[batch_graph].unsqueeze(-1),
            probs[batch_graph].unsqueeze(-1))).sum() / 2

    # calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(
        probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0,
        dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                              self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    # useful numbers
    max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row],
                                  batch[no_loop_row], 0,
                                  dim_size=num_graphs) / 2).squeeze(-1)
    set_weight = (scatter_add(x2[no_loop_row] * x2[no_loop_col],
                              batch[no_loop_row], 0,
                              dim_size=num_graphs) / 2) + 1e-6
    clique_edges_hard = (set_size * (set_size - 1) / 2) + 1e-6
    clique_dist_hard = set_weight / clique_edges_hard
    clique_check = ((clique_edges_hard != clique_edges_hard))
    setedge_check = ((set_weight != set_weight))
    assert ((clique_dist_hard >= 1.1).sum()) <= 1e-6, \
        "Invalid set vol/clique vol ratio."

    # calculate loss
    expected_loss = (penalty_coefficient) * expected_distance * 0.5 \
        - 0.5 * expected_weight_G
    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
    retdict["Expected_cardinality_hist"] = [card_1, "hist"]
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Set sizes"] = [set_size.squeeze(-1), "hist"]
    retdict["volume_hard"] = [vol_hard.mean(), "aux"]  # volume2
    retdict["cardinality_hard"] = [set_size[0], "sequence"]  # volumeq
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [expected_clique_weight.mean(),
                                          "sequence"]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(), 'sequence']
    retdict["Currvol/Cliquevol all graphs in batch"] = [
        clique_dist_hard.squeeze(-1), 'hist'
    ]
    retdict["Average ratio of total volume"] = [total_vol_ratio.mean(),
                                                'sequence']
    # `cardinalities` was undefined in the original; `card_1` (the expected
    # cardinality per graph) is the most plausible intent.
    retdict["cardinalities"] = [card_1.squeeze(-1), "hist"]
    retdict["Current edge percentage"] = [
        torch.tensor(current_edge_percentage), 'sequence'
    ]
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss
    return retdict
def prepare_data_for_link_prediction(datalist, train_ratio=0.8,
                                     neg_to_pos_edge_ratio=1,
                                     rnd_labeled_edges=True):
    """For each graph, splits the edges into training and testing sets (each
    with its own set of negative examples).

    rnd_labeled_edges=True means that the positive and negative edges for
    training are chosen at random (at different epochs, the same graph can
    have different positive/negative edges chosen for training)."""
    train_data_list = []
    test_data_list = []

    for graph in datalist:
        train_graph = graph
        test_graph = train_graph.clone()

        # Create negative edge examples
        ei_without_double_edges = remove_double_edges(graph.edge_index)
        ei_with_self_loops, _ = add_remaining_self_loops(
            ei_without_double_edges, num_nodes=graph.num_nodes)
        neg_edge_index = negative_sampling(
            edge_index=ei_with_self_loops, num_nodes=graph.num_nodes,
            num_neg_samples=neg_to_pos_edge_ratio *
            ei_without_double_edges.size(1),
            shuffle_neg_egdes=rnd_labeled_edges)

        num_train_pos_edges = math.floor(
            ei_without_double_edges.size(1) * train_ratio)
        num_train_neg_edges = math.floor(
            neg_edge_index.size(1) * train_ratio)

        # Split positive edges
        if rnd_labeled_edges:
            perm = torch.randperm(ei_without_double_edges.size(1))
            row, col = (ei_without_double_edges[0][perm],
                        ei_without_double_edges[1][perm])
        else:
            row, col = (ei_without_double_edges[0],
                        ei_without_double_edges[1])
        train_graph.pos_edge_index = torch.stack(
            [row[:num_train_pos_edges], col[:num_train_pos_edges]], dim=0)
        test_graph.pos_edge_index = torch.stack(
            [row[num_train_pos_edges:], col[num_train_pos_edges:]], dim=0)

        # Update edge_index for message-passing for link prediction
        # (no test edges)
        train_graph.edge_index = to_undirected(
            train_graph.pos_edge_index, num_nodes=train_graph.num_nodes)
        test_graph.edge_index = train_graph.edge_index

        # Split negative edges
        if rnd_labeled_edges:
            perm = torch.randperm(neg_edge_index.size(1))
            row, col = neg_edge_index[0][perm], neg_edge_index[1][perm]
        else:
            row, col = neg_edge_index[0], neg_edge_index[1]
        train_graph.neg_edge_index = torch.stack(
            [row[:num_train_neg_edges], col[:num_train_neg_edges]], dim=0)
        test_graph.neg_edge_index = torch.stack(
            [row[num_train_neg_edges:], col[num_train_neg_edges:]], dim=0)

        train_data_list.append(train_graph)
        test_data_list.append(test_graph)

    return train_data_list, test_data_list