def forward(self, data):
    """Encode a batch of graphs and decode one output row per graph.

    The per-graph summary ("CoC") starts as the mean of the last three node
    feature columns weighted by column ``-5``, concatenated with
    ``self.scatter_norm`` statistics of the raw node features. After each
    convolution in ``self.convs`` the paired module from
    ``self.scatter_norms`` appends statistics of the updated node features.

    Args:
        data: batch exposing ``x``, ``edge_attr`` (read, unused),
            ``edge_index`` and ``batch``.

    Returns:
        ``self.decoder`` applied to the accumulated per-graph summary.
    """
    raw_x, _unused_edge_attr, edge_index, graph_of_node = (
        data.x, data.edge_attr, data.edge_index, data.batch)
    node_feats = raw_x.float()

    weight = node_feats[:, -5].view(-1, 1)
    weighted_pos = scatter_sum(node_feats[:, -3:] * weight, graph_of_node, dim=0)
    weight_total = scatter_sum(weight, graph_of_node, dim=0)
    summary = torch.cat(
        [weighted_pos / weight_total, self.scatter_norm(node_feats, graph_of_node)],
        dim=1,
    )

    node_feats = self.x_encoder(node_feats)
    summary = self.CoC_encoder(summary)
    for layer, layer_norm in zip(self.convs, self.scatter_norms):
        node_feats = layer(node_feats, edge_index)
        # The summary grows by one block of statistics per layer.
        summary = torch.cat([summary, layer_norm(node_feats, graph_of_node)], dim=1)
    return self.decoder(summary)
def test_qtensor_scatter_idx(self):
    """Scatter-summing a stacked QTensor equals scatter-summing each component.

    Rows are scattered to random targets in ``[0, 256)`` and the output is
    padded to the input row count via ``dim_size``.
    """
    n_rows, n_feats = 1024, 64
    index = torch.randint(low=0, high=256, size=(n_rows, ), dtype=torch.int64)
    quat = QTensor(*torch.randn(4, n_rows, n_feats))

    stacked = quat.stack(dim=1)
    assert stacked.size() == torch.Size([n_rows, 4, n_feats])

    summed = scatter_sum(src=stacked, index=index, dim=0, dim_size=stacked.size(0))
    assert summed.size() == stacked.size()
    from_stacked = QTensor(*summed.permute(1, 0, 2))

    per_component = [
        scatter_sum(part, index, dim=0, dim_size=quat.size(0))
        for part in (quat.r, quat.i, quat.j, quat.k)
    ]
    assert from_stacked == QTensor(*per_component)
def contrastive_loss(encoder_output, graph_data, sim_metric):
    """Contrastive (InfoNCE-style) loss over two families of positive pairs.

    Pairs from ``graph_data.node_pos_index`` compare two rows of ``es``.
    Pairs from ``graph_data.edge_pos_index`` compare a row of ``es`` with a
    row of ``ps``.

    Each positive similarity ``s`` becomes ``s / (s + N)``. ``N`` is the sum
    of the sampled negative similarities whose first index equals the
    positive's first index. The loss is the mean of ``-log`` over all ratios.

    Args:
        encoder_output: pair ``(es, ps)`` of embedding tensors.
        graph_data: provides ``x``, the two-row positive index tensors and
            the negative pools handed to ``_contrastive_sample``.
        sim_metric: callable ``(a, b, flatten=True, method='exp')`` giving one
            similarity per row pair.

    Returns:
        Scalar loss tensor.
    """
    es, ps = encoder_output
    # Size of the negative-sum buffer; presumably the number of es rows (TODO confirm).
    num_anchor_rows = graph_data.x[0].size(0)

    node_pairs = graph_data.node_pos_index
    node_negs = _contrastive_sample(node_pairs.size(1), graph_data.node_neg_index)
    cross_pairs = graph_data.edge_pos_index
    cross_negs = _contrastive_sample(cross_pairs.size(1), graph_data.edge_neg_index)

    def similarity(left, right, pairs):
        lhs, rhs = left.index_select(0, pairs[0]), right.index_select(0, pairs[1])
        return sim_metric(lhs, rhs, flatten=True, method='exp')

    def ratio(pos_sim, neg_sim, pos_pairs, neg_pairs):
        neg_total = scatter_sum(neg_sim, neg_pairs[0], dim=-1, dim_size=num_anchor_rows)
        return pos_sim / (pos_sim + neg_total.index_select(0, pos_pairs[0]))

    node_pos_sim = similarity(es, es, node_pairs)
    node_neg_sim = similarity(es, es, node_negs)
    cross_pos_sim = similarity(es, ps, cross_pairs)
    cross_neg_sim = similarity(es, ps, cross_negs)

    node_ratio = ratio(node_pos_sim, node_neg_sim, node_pairs, node_negs)
    cross_ratio = ratio(cross_pos_sim, cross_neg_sim, cross_pairs, cross_negs)
    return torch.cat([-node_ratio.log(), -cross_ratio.log()], dim=-1).mean()
def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                         out: Optional[torch.Tensor] = None,
                         dim_size: Optional[int] = None,
                         unbiased: bool = True) -> torch.Tensor:
    """Summarise each group of ``src`` selected by ``index`` along ``dim``.

    Returns, concatenated along dimension 1: per-group sum, mean, variance,
    maximum and minimum.

    Args:
        src: values to reduce.
        index: group id per element along ``dim`` (broadcast to ``src``).
        dim: reduction dimension; negative values count from the end.
        out: only its size along ``dim`` is used, to fix the number of
            groups. It used to also serve as the write buffer for the
            variance, maximum and minimum reductions. Those aliased one
            another, so passing ``out`` corrupted the result.
        dim_size: number of groups; inferred from ``index`` when ``None``.
        unbiased: divide squared deviations by ``count - 1`` (clamped to at
            least 1) instead of ``count``.

    Returns:
        Tensor of the five statistics concatenated along dimension 1.
    """
    if out is not None:
        dim_size = out.size(dim)
    if dim < 0:
        dim = src.dim() + dim

    # ``index`` may have fewer dims than ``src``; count along its last one then.
    count_dim = dim if index.dim() > dim else index.dim() - 1
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    total = scatter_sum(src, index, dim, dim_size=dim_size)
    # Empty groups are divided by 1 so mean/variance are 0 rather than NaN.
    count = broadcast(count, total, dim).clamp(1)
    mean = total.div(count)

    deviation = src - mean.gather(dim, index)
    var = scatter_sum(deviation * deviation, index, dim, dim_size=dim_size)
    if unbiased:
        count = count.sub(1).clamp_(1)
    var = var.div(count)

    maximum = scatter_max(src, index, dim, dim_size=dim_size)[0]
    minimum = scatter_min(src, index, dim, dim_size=dim_size)[0]
    return torch.cat([total, mean, var, maximum, minimum], dim=1)
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    """Per-group standard deviation of ``src`` along ``dim``.

    Groups with no members get a divisor of 1. With ``unbiased`` the squared
    deviations are divided by ``count - 1``, clamped to at least 1. When
    ``out`` is given, the summed squared deviations are accumulated into it
    and its size along ``dim`` fixes the number of groups.
    """
    if out is not None:
        dim_size = out.size(dim)
    axis = dim + src.dim() if dim < 0 else dim
    count_axis = min(axis, index.dim() - 1)

    group_sizes = scatter_sum(
        torch.ones(index.size(), dtype=src.dtype, device=src.device),
        index, count_axis, dim_size=dim_size)

    expanded_index = broadcast(index, src, axis)
    group_sums = scatter_sum(src, expanded_index, axis, dim_size=dim_size)
    group_sizes = broadcast(group_sizes, group_sums, axis).clamp(1)

    centred = src - group_sums.div(group_sizes).gather(axis, expanded_index)
    squared = scatter_sum(centred * centred, expanded_index, axis, out, dim_size)

    divisor = group_sizes.sub(1).clamp_(1) if unbiased else group_sizes
    return squared.div(divisor).sqrt()
def forward(self, *data):
    """Predict one output row per graph.

    Accepts either a single batch object (anything exposing ``x`` and,
    optionally, ``batch`` — e.g. a PyG ``Batch``) or one or more node-feature
    tensors. Each tensor becomes its own graph, collated into a ``Batch``.

    Because ``*data`` always packs arguments into a tuple, the old
    ``type(data) == tuple`` test was always true. A plain ``model(batch)``
    call was therefore re-wrapped as ``Data(x=batch)`` and broke. The
    fallback path also read an undefined global ``model``.

    The per-graph summary starts as the mean of the last three feature
    columns weighted by column ``-5``, plus ``self.scatter_norm`` statistics.
    It is refined by ``self.convs`` together with per-node hidden state
    ``h``. ``self.scatter_norm2`` statistics of the final node features are
    appended before ``self.decoder``.
    """
    if len(data) == 1 and not isinstance(data[0], torch.Tensor):
        data = data[0]
    else:
        from torch_geometric.data import Batch, Data
        data = Batch.from_data_list([Data(x=x) for x in data])

    x = data.x
    batch = getattr(data, 'batch', None)
    if batch is None:
        # A single un-batched graph: every node belongs to graph 0.
        batch = torch.zeros(x.shape[0], device=x.device, dtype=torch.int64)

    x = x.float()
    charge = x[:, -5].view(-1, 1)
    CoC = scatter_sum(x[:, -3:] * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)
    CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

    x = self.x_encoder(x)
    CoC = self.CoC_encoder(CoC)
    h = torch.zeros((x.shape[0], self.hcs), device=self.device)
    for conv in self.convs:
        x, CoC, h = conv(x, CoC, h, batch)
    CoC = torch.cat([CoC, self.scatter_norm2(x, batch)], dim=1)
    return self.decoder(CoC)
def inverse_eig(X, edge_index, edge_weight, batch, num_iters=20):
    """Element-wise inverse of a per-graph power-iteration vector.

    Starting from all ones, runs ``num_iters`` steps of
    ``Z <- (A Z) / 10``. Here ``A`` is the weighted adjacency given by
    ``edge_index``/``edge_weight``, with messages flowing from
    ``edge_index[1]`` to ``edge_index[0]``. After each step ``Z`` is rescaled
    to unit L2 norm within each graph of ``batch``. Presumably this
    approximates the dominant eigenvector; the ``/ 10`` is cancelled by the
    normalisation. No gradients are tracked.

    Args:
        X: node features; only its row count and device are used.
        edge_index: ``[2, E]`` integer tensor.
        edge_weight: ``[E]`` edge weights.
        batch: graph id per node.
        num_iters: number of iterations (default 20, the former fixed value).

    Returns:
        ``1 / (1e-4 + Z)`` with shape ``[num_nodes, 1]``. Graphs whose vector
        collapses to zero yield NaN from the 0/0 normalisation.
    """
    num_nodes = X.shape[0]
    with torch.no_grad():
        # Follow X's device instead of hard-coding CUDA, so CPU inputs work.
        Z = torch.ones((num_nodes, 1), device=X.device)
        for _ in range(num_iters):
            # dim_size keeps one row per node even if the last nodes receive no edges.
            Z = torch_scatter.scatter_sum(
                edge_weight[:, None] * Z[edge_index[1]], edge_index[0],
                dim=0, dim_size=num_nodes) / 10
            Z = Z / torch_scatter.scatter_sum(Z ** 2, batch, dim=0).sqrt()[batch]
    return 1 / (1e-4 + Z)
def forward(self, data):
    """Message passing with one central node per graph; returns one row per graph.

    1. Append ``scatter_distribution`` statistics of incoming edge
       attributes (grouped by target ``edge_index[1]``) to every node.
    2. Build a central node per graph: the mean of the last three feature
       columns weighted by column 0, plus ``scatter_distribution`` statistics
       of the augmented node features.
    3. Link each node to its graph's centre. Link features are the unit
       direction and distance from centre to node, plus the node's augmented
       features.
    4. For every block ``i``:
       - ``self.ops[i]`` updates nodes, edges and centres;
       - per-node state ``h`` is refreshed by ``self.GRUCells[i]`` before and
         after the centre update (``lins1``/``lins2``);
       - nodes absorb ``h`` via ``lins3``.
    5. ``self.decoders`` (each followed by ``self.act``) and then
       ``self.decoder`` map the centres to the output.
    """
    x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
    pos = x[:, -3:]
    # dim_size keeps one statistics row per node even when the highest-numbered
    # nodes receive no edges; without it the concatenation fails.
    x = torch.cat(
        [x, scatter_distribution(edge_attr, edge_index[1], dim=0, dim_size=x.shape[0])],
        dim=1)

    charge = x[:, 0].view(-1, 1)
    CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)
    CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)], dim=1)

    # Row 0: node id, row 1: id of that node's graph centre.
    CoC_edge_index = torch.cat([
        torch.arange(x.shape[0]).view(1, -1).type_as(batch),
        batch.view(1, -1)
    ], dim=0)

    cart = pos[CoC_edge_index[0], -3:] - CoC[CoC_edge_index[1], :3]
    del pos
    rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
    # Out-of-place normalisation. torch.norm saved ``cart`` for backward, and
    # the former masked in-place write broke autograd when inputs require grad.
    # Zero-length rows are divided by 1, i.e. left unchanged as before.
    cart = cart / rho.masked_fill(rho == 0, 1)
    CoC_edge_attr = torch.cat(
        [cart.type_as(x), rho.type_as(x), x[CoC_edge_index[0]]], dim=1)

    x = self.act(self.x_encoder(x))
    edge_attr = self.act(self.edge_attr_encoder(edge_attr))
    CoC = self.act(self.CoC_encoder(CoC))

    h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
    for i, op in enumerate(self.ops):
        x, edge_attr, CoC = op(x, edge_index, edge_attr, CoC, batch)
        h = self.act(self.GRUCells[i](torch.cat([
            CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
        ], dim=1), h))
        CoC = self.act(self.lins1[i](torch.cat(
            [CoC, scatter_distribution(h, batch, dim=0)], dim=1)))
        CoC = self.act(self.lins2[i](CoC))
        h = self.act(self.GRUCells[i](torch.cat([
            CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
        ], dim=1), h))
        x = self.act(self.lins3[i](torch.cat([x, h], dim=1)))

    for lin in self.decoders:
        CoC = self.act(lin(CoC))
    return self.decoder(CoC)
def forward(self, data):
    """Encode a batch of graphs around their centre of charge; one output row per graph.

    Node features are augmented with:
      - ``x_feature_constructor`` output;
      - ``edge_feature_constructor`` features along ``time_edge_index``;
      - the features of each time edge's source node;
      - the node's unit direction and distance from its graph's centre of
        charge (mean position weighted by column 0);
      - that centre's coordinates.

    The graph summary is the centre plus ``scatter_distribution`` statistics
    of the augmented features. It is refined by the conv layers, and a final
    statistic of the node states is appended before ``self.decoder``.

    NOTE(review): the layer loop runs ``N_metalayers`` times. That name is
    defined outside this method (presumably a module constant that must not
    exceed ``len(self.convs)``). Confirm before switching to
    ``for conv in self.convs``.

    NOTE(review): concatenating ``x[time_edge_index[0]]`` with ``x`` assumes
    ``time_edge_index`` has exactly one column per node — TODO confirm.
    """
    x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
    x = x.float()
    pos = x[:, -3:]
    _, graph_node_counts = batch.unique(return_counts=True)
    time_edge_index = time_edge_indeces(x[:, 1], batch)
    edge_attr = edge_feature_constructor(x, time_edge_index)

    # Central node per graph at the centre of charge.
    charge = x[:, 0].view(-1, 1)
    CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)

    # Node-to-centre edge attributes: unit direction and distance.
    cart = pos[:, -3:] - CoC[batch, :3]
    del pos
    rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
    # Out-of-place normalisation. torch.norm saved ``cart`` for backward, so
    # the former masked in-place write broke autograd when inputs require grad.
    # Zero-length rows are divided by 1, i.e. left unchanged.
    cart = cart / rho.masked_fill(rho == 0, 1)
    CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)

    x = torch.cat([
        x,
        x_feature_constructor(x, graph_node_counts),
        edge_attr,
        x[time_edge_index[0]],
        CoC_edge_attr,
        CoC[batch]
    ], dim=1)
    CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)], dim=1)

    x = self.x_encoder(x)
    CoC = self.CoC_encoder(CoC)
    h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
    for i in range(N_metalayers):
        x, CoC, h = self.convs[i](x, CoC, h, batch)
    CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)], dim=1)
    return self.decoder(CoC)
def forward(self, x, x_clique, tree_edge_index, atom2clique_index, u, tree_batch):
    """Update clique (tree-node) embeddings.

    Each clique combines four parts, concatenated and fed to
    ``self.subgraph_mlp``:
      - ``mlp2`` of the summed features of its atoms (``atom2clique_index``);
      - its own features;
      - ``mlp1`` of the summed features of neighbouring cliques along
        ``tree_edge_index``;
      - its tree's global feature ``u[tree_batch]``.
    """
    num_cliques = x_clique.size(0)

    src_clique, dst_clique = tree_edge_index
    neighbour_msg = self.mlp1(
        scatter_sum(x_clique[src_clique], dst_clique, dim=0, dim_size=num_cliques))

    atom_idx, clique_idx = atom2clique_index
    atom_msg = self.mlp2(
        scatter_sum(x[atom_idx], clique_idx, dim=0, dim_size=num_cliques))

    combined = torch.cat([atom_msg, x_clique, neighbour_msg, u[tree_batch]], dim=1)
    return self.subgraph_mlp(combined)
def forward(self, *data):
    """Predict one output row per graph (usable by attribution tools that pass raw tensors).

    Accepts either a single batch object (exposing ``x`` and, optionally,
    ``batch``) or node-feature tensors. A 2-D tensor is one graph; a tensor
    with more dims is split along its first axis, one graph per slice. The
    tensors are collated into a PyG ``Batch``. A single batch argument is
    used as-is; previously it was wrapped as ``Data(x=batch)``. The fallback
    no longer reads the undefined global ``model``.

    The per-graph summary starts as the mean of the last three feature
    columns weighted by column ``-5``, with NaN (zero total weight) replaced
    by 0, plus ``self.scatter_norm`` statistics. The summary after every conv
    layer is concatenated together with ``self.scatter_norm2`` of the final
    node features before ``self.decoder``.
    """
    if len(data) == 1 and not isinstance(data[0], torch.Tensor):
        data = data[0]
    else:
        from torch_geometric.data import Batch, Data
        datalist = []
        for x in data:
            if x.dim() > 2:
                datalist.extend(Data(x=tmp_x.squeeze()) for tmp_x in x)
            else:
                datalist.append(Data(x=x.squeeze()))
        data = Batch.from_data_list(datalist)

    x = data.x
    batch = getattr(data, 'batch', None)
    if batch is None:
        # A single un-batched graph: every node belongs to graph 0.
        batch = torch.zeros(x.shape[0], device=x.device, dtype=torch.int64)

    x = x.float()
    charge = x[:, -5].unsqueeze(-1)
    CoC = scatter_sum(x[:, -3:] * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)
    # A graph whose total weight is zero gives 0/0; use 0 instead of NaN.
    CoC[CoC.isnan()] = 0
    CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

    x = self.x_encoder(x)
    CoC = self.CoC_encoder(CoC)
    h = torch.zeros((x.shape[0], self.hcs), device=self.device)

    # Concatenate the summary after every layer for the decoder.
    out = CoC.clone()
    for conv in self.convs:
        x, CoC, h = conv(x, CoC, h, batch)
        out = torch.cat([out, CoC], dim=1)
    CoC = torch.cat([out, self.scatter_norm2(x, batch, out)], dim=1)
    return self.decoder(CoC)
def return_CoC_and_edge_attr(self, x, batch):
    """Compute each graph's centre of charge and the node-to-centre edge features.

    Positions are the last three columns of ``x``; charges are column 0.

    Returns:
        CoC: ``[num_graphs, 3]`` charge-weighted mean position per graph
            (NaN for a graph whose total charge is 0).
        CoC_edge_attr: ``[num_nodes, 4]``. The first three columns are the
            unit vector from the node's graph centre to the node (the raw
            zero vector when the node sits on the centre); the last column is
            the distance.
    """
    pos = x[:, -3:]
    charge = x[:, 0].view(-1, 1)
    CoC = scatter_sum(pos * charge, batch, dim=0) / scatter_sum(charge, batch, dim=0)

    cart = pos - CoC[batch]
    rho = torch.norm(cart, p=2, dim=1).view(-1, 1)
    # Normalise out of place. torch.norm saved ``cart`` for backward, so the
    # former masked in-place write raised an autograd error whenever ``x``
    # requires grad (e.g. for attributions). Dividing zero-length rows by 1
    # leaves them unchanged, exactly as before.
    cart = cart / rho.masked_fill(rho == 0, 1)
    CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)
    return CoC, CoC_edge_attr
def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                         out: Optional[torch.Tensor] = None,
                         dim_size: Optional[int] = None,
                         unbiased: bool = True) -> torch.Tensor:
    """Summarise each group of ``src`` selected by ``index`` along ``dim``.

    Returns, concatenated along dimension 1: per-group mean, variance,
    skewness-like and kurtosis-like moments, maximum and minimum.

    The two higher moments average the standardised 3rd and 4th powers of
    the deviations. They use the same divisor as the variance (``count - 1``
    when ``unbiased``). A ``1e-7`` term guards against division by zero for
    constant groups.

    Args:
        src: values to reduce.
        index: group id per element along ``dim`` (broadcast to ``src``).
        dim: reduction dimension; negative values count from the end.
        out: only its size along ``dim`` is used, to fix the number of
            groups. It used to be the write buffer for variance, skew,
            kurtosis, max and min simultaneously. Skew then accumulated on
            top of the variance sums and aliased kurtosis, corrupting the
            output.
        dim_size: number of groups; inferred from ``index`` when ``None``.
        unbiased: use ``count - 1`` (clamped to at least 1) as divisor.

    Returns:
        Tensor of the six statistics concatenated along dimension 1.
    """
    if out is not None:
        dim_size = out.size(dim)
    if dim < 0:
        dim = src.dim() + dim

    # ``index`` may have fewer dims than ``src``; count along its last one then.
    count_dim = dim if index.dim() > dim else index.dim() - 1
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    # Empty groups are divided by 1 so their statistics are 0 rather than NaN.
    count = broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    src_minus_mean = src - mean.gather(dim, index)
    var = scatter_sum(src_minus_mean * src_minus_mean, index, dim, dim_size=dim_size)
    if unbiased:
        count = count.sub(1).clamp_(1)
    var = var.div(count)

    skew = src_minus_mean * src_minus_mean * src_minus_mean / (
        var.gather(dim, index) + 1e-7)**(1.5)
    kurtosis = (src_minus_mean * src_minus_mean * src_minus_mean *
                src_minus_mean) / (var * var + 1e-7).gather(dim, index)
    skew = scatter_sum(skew, index, dim, dim_size=dim_size).div(count)
    kurtosis = scatter_sum(kurtosis, index, dim, dim_size=dim_size).div(count)

    maximum = scatter_max(src, index, dim, dim_size=dim_size)[0]
    minimum = scatter_min(src, index, dim, dim_size=dim_size)[0]
    return torch.cat([mean, var, skew, kurtosis, maximum, minimum], dim=1)
def scatter_mul(src, edge_index, edge_attr=None, dim=0):
    """Sum (optionally edge-weighted) source slices into their target positions.

    Despite the name this is a sum aggregation. For every edge ``(s, t)`` the
    slice of ``src`` at ``s`` along ``dim`` is added into position ``t``,
    scaled by the edge's weight when ``edge_attr`` is given.

    Args:
        src: values to propagate.
        edge_index: ``[2, E]`` integer tensor; row 0 = source, row 1 = target.
        edge_attr: optional per-edge weights with ``E`` entries along dim 0.
            They are cast to the dtype of ``src``; the old unconditional
            ``.long()`` cast truncated fractional weights such as 0.5 to 0.
            A 1-D weight vector is reshaped to line up with ``dim``.
        dim: dimension of ``src`` that indexes nodes.

    Returns:
        Aggregated tensor. Its size along ``dim`` is inferred by
        ``scatter_sum`` from ``edge_index[1]``.

    Raises:
        ValueError: if ``edge_attr`` does not have one entry per edge.
    """
    scatter_src = src.index_select(dim, edge_index[0])
    if edge_attr is not None:
        if edge_index.size(1) != edge_attr.size(0):
            raise ValueError(
                f"edge_attr has {edge_attr.size(0)} entries but edge_index "
                f"has {edge_index.size(1)} edges")
        weights = edge_attr.to(scatter_src.dtype)
        if weights.dim() == 1 and scatter_src.dim() > 1:
            # Place the edge axis on ``dim`` so each weight scales its own slice.
            shape = [1] * scatter_src.dim()
            shape[dim] = -1
            weights = weights.view(shape)
        scatter_src = scatter_src * weights
    return scatter_sum(scatter_src, edge_index[1], dim)
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Numerically stable per-group ``log(sum(exp(src)))`` along ``dim``.

    Each group is recentred by its maximum before exponentiating. When
    ``out`` is given, its existing values take part in each group's sum and
    the result is written into it. Groups with no members evaluate to
    ``log(eps) - inf = -inf``.

    Raises:
        ValueError: if ``src`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    # torch.Size is an immutable tuple; copy to a list before editing it.
    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    # Fills max_value_per_index in place with each group's maximum.
    scatter_max(src, index, dim, max_value_per_index, dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    # (-inf) - (-inf) is NaN; such entries should contribute exp(-inf) = 0.
    recentered_scores.masked_fill_(torch.isnan(recentered_scores), float('-inf'))

    if out is not None:
        # ``out`` is group-shaped, so recentre it by the per-group maxima. The
        # per-element maxima used previously have src's shape.
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
                                dim_size)
    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
def forward(self, data):
    """Virtual-node update.

    Steps:
      1. Sum the node states ``data.h`` per graph (``data.batch``).
      2. Add the previous virtual-node state ``data.vn_h`` when present.
      3. Pass the result through ``self.mlp`` to obtain the new ``data.vn_h``.
      4. Add each graph's virtual-node state onto its nodes' ``data.h``.

    ``data.h`` is rebound to a new tensor rather than updated with ``+=``.
    The in-place add also mutated any tensor the caller still held (e.g. for
    a residual connection). It could also trip autograd's in-place checks
    when the previous layer saved ``data.h`` for backward.

    Returns:
        The same ``data`` object with ``vn_h`` and ``h`` replaced.
    """
    pooled = scatter_sum(data.h, data.batch, dim=0)
    if 'vn_h' in data:
        pooled = pooled + data.vn_h
    data.vn_h = self.mlp(pooled)
    data.h = data.h + data.vn_h[data.batch]
    return data
def weighted_dimwise_median(A: torch.sparse.FloatTensor, x: torch.Tensor, **kwargs) -> torch.Tensor:
    """A weighted dimension-wise Median aggregation.

    Non-CUDA inputs are delegated to ``weighted_dimwise_median_cpu``. On
    CUDA, ``custom_cuda_kernels.dimmedian_idx`` picks a row of ``x`` per
    output row and column (presumably the weighted median — confirm against
    the kernel). The selected values are scaled by the row sums of ``A``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d]
    """
    if not A.is_cuda:
        return weighted_dimwise_median_cpu(A, x, **kwargs)

    assert A.is_sparse
    N, D = x.shape

    median_idx = custom_cuda_kernels.dimmedian_idx(x, A)
    col_idx = torch.arange(D, device=A.device).view(1, -1).expand(N, D)
    x_selected = x[median_idx, col_idx]

    # dim_size=N keeps one sum per row even when the last rows of A have no
    # stored entries. Without it the result is shorter than N and the
    # expand to [N, D] fails.
    a_row_sum = torch_scatter.scatter_sum(
        A._values(), A._indices()[0], dim=-1, dim_size=N).view(-1, 1).expand(N, D)
    return a_row_sum * x_selected
def forward(self, x, edge_index, edge_attr, u, batch):
    """Node update: aggregate incoming messages, then mix with node and global state.

    Each message is the source node's features concatenated with the edge
    attributes. Messages are summed at their target node and passed through
    ``node_mlp_1``. The result is concatenated with the node's own features
    and its graph's global feature ``u[batch]`` before ``node_mlp_2``.
    """
    src_node, dst_node = edge_index
    messages = torch.cat([x[src_node], edge_attr], dim=1)
    aggregated = self.node_mlp_1(
        scatter_sum(messages, dst_node, dim=0, dim_size=x.size(0)))
    return self.node_mlp_2(torch.cat([x, aggregated, u[batch]], dim=1))
def forward(self, inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
    """Multi-head attention pooling of element embeddings into one vector per sample.

    Each sample's query (from ``self.__query_layer``) is dotted, per head,
    with each of its elements' projected values. The scores are
    softmax-normalised over the sample's elements. The per-head weighted
    sums of the raw element embeddings are flattened and passed through
    ``self.__output_layer``.
    """
    query = self.__query_layer(inputs)  # [num_graphs, H]
    query_per_node = query[inputs.element_to_sample_map]  # [num_vertices, H]
    values = self.__value_layer(inputs.element_embeddings)  # [num_vertices, H]

    # Split the query into heads. The original reshaped ``values`` here,
    # which discarded the query and scored values against themselves.
    query_per_node = query_per_node.reshape(
        (query_per_node.shape[0], self.__num_heads, query_per_node.shape[1] // self.__num_heads)
    )
    values = values.reshape(
        (values.shape[0], self.__num_heads, values.shape[1] // self.__num_heads)
    )

    attention_scores = torch.einsum(
        "vkh,vkh->vk", query_per_node, values
    )  # [num_vertices, num_heads]
    attention_probs = torch.exp(
        scatter_log_softmax(attention_scores, index=inputs.element_to_sample_map, dim=0, eps=0)
    )  # [num_vertices, num_heads]

    outputs = attention_probs.unsqueeze(-1) * inputs.element_embeddings.unsqueeze(
        1
    )  # [num_vertices, num_heads, D']
    per_graph_outputs = scatter_sum(
        outputs, index=inputs.element_to_sample_map, dim=0, dim_size=inputs.num_samples
    )  # [num_graphs, num_heads, D']
    per_graph_outputs = per_graph_outputs.reshape(
        (per_graph_outputs.shape[0], -1)
    )  # [num_graphs, num_heads * D']

    return self.__output_layer(per_graph_outputs)  # [num_graphs, D']
def forward(self, inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
    """Gated sum pooling.

    Each element embedding is scaled by a sigmoid gate from
    ``self.__weights_layer``. The scaled embeddings are then summed per
    sample.

    Returns:
        ``[num_samples, D']`` tensor.
    """
    embeddings = inputs.element_embeddings
    gate = torch.sigmoid(self.__weights_layer(embeddings).squeeze(-1))  # [num_vertices]
    gated = embeddings * gate.unsqueeze(-1)
    return scatter_sum(
        gated,
        index=inputs.element_to_sample_map,
        dim=0,
        dim_size=inputs.num_samples,
    )
def scatter_softmax(
    src: Tensor, index: Tensor, dim: int, dim_size: Optional[int] = None
) -> Tensor:
    """Softmax of ``src`` within the groups given by ``index`` along ``dim``.

    ``index`` is used to gather group results back along ``dim``; this
    presumes a 1-D index with one entry per position along ``dim``. Empty
    ``src`` is returned unchanged.
    """
    if src.numel() == 0:
        return src
    if dim < 0:
        # ``(slice(None),) * dim`` is empty for negative ``dim``, which
        # silently gathered along dimension 0; normalise first.
        dim = dim + src.dim()
    # Gathers the group-shaped reductions back to src's shape along ``dim``.
    slice_tuple = (slice(None),) * dim + (index,)
    # Subtract each group's maximum before exp for numerical stability.
    src = src - scatter_max(src, index, dim, dim_size=dim_size)[0][slice_tuple]
    exp = torch.exp(src)
    return exp / scatter_sum(exp, index, dim, dim_size=dim_size)[slice_tuple]
def aggregate(self, inputs: torch.Tensor, index: torch.Tensor,
              dim_size: Optional[int] = None) -> torch.Tensor:
    """Softmax aggregation of messages.

    Messages sharing a target in ``index`` are weighted by the softmax of
    ``inputs * self.beta``, computed along ``self.node_dim``. The weighted
    messages are then summed per target.
    """
    axis = self.node_dim
    weights = scatter_softmax(inputs * self.beta, index, dim=axis)
    return scatter_sum(weights * inputs, index, dim=axis, dim_size=dim_size)
def forward(self, x, edge_index, edge_attr, batch):
    """Node update computed from incoming edge attributes only.

    Args:
        x: [N, h] node features; only the row count N is used.
        edge_index: [2, E] with max entry N - 1; row 1 holds edge targets.
        edge_attr: [E, F_e] edge features.
        batch: [N] graph id per node with max entry B - 1 (unused here).

    Returns:
        ``node_mlp_2`` of the per-node sum of ``node_mlp_1(edge_attr)`` over
        incoming edges.
    """
    _, targets = edge_index
    edge_msgs = self.node_mlp_1(edge_attr)
    node_sums = scatter_sum(edge_msgs, targets, dim=0, dim_size=x.size(0))
    return self.node_mlp_2(node_sums)
def forward(self, X, edge_index, edge_weight, batch):
    """Stack of weighted graph convolutions with optional residuals and normalisation.

    Each layer ``m`` in ``self.intermediate`` computes ``m[0](X)`` plus the
    ``edge_weight``-scaled sum of ``m[1](X)`` over edges. Messages flow from
    ``edge_index[1]`` into ``edge_index[0]``. The update is added to ``X``
    when ``self.res`` is set and replaces it otherwise. It is followed by
    ``self.norm[idx]`` and a LeakyReLU (slope 0.01).

    Raises:
        ValueError: if a layer produces NaN values.
    """
    X = self.start(X)
    for idx, m in enumerate(self.intermediate):
        self_term = m[0](X)
        # dim_size keeps one row per node even if the last nodes receive no edges.
        neighbour_term = torch_scatter.scatter_sum(
            edge_weight[:, None] * m[1](X)[edge_index[1]], edge_index[0],
            dim=0, dim_size=X.shape[0])
        Update = self_term + neighbour_term
        X = X + Update if self.res else Update
        X = torch.nn.functional.leaky_relu(self.norm[idx](X, edge_index, edge_weight, batch))
        if torch.isnan(X).any():
            raise ValueError(f"NaN values produced by intermediate layer {idx}")
    return self.finish(X)
def forward(self, X, edge_index, edge_weight, batch):
    """Stack of weighted graph convolutions with batch normalisation.

    ``X`` is projected by ``self.start``. Each layer ``m`` then computes
    ``m[0](X)`` plus the ``edge_weight``-scaled sum of ``m[1](X)`` over
    edges, with messages flowing from ``edge_index[1]`` into
    ``edge_index[0]``. This is followed by ``self.bn[idx]`` and a LeakyReLU
    (slope 0.01). ``self.finish`` projects the result. ``batch`` is not used.
    """
    X = self.start(X)
    for idx, m in enumerate(self.intermediate):
        self_term = m[0](X)
        # dim_size keeps one row per node even if the last nodes receive no edges.
        neighbour_term = torch_scatter.scatter_sum(
            edge_weight[:, None] * m[1](X)[edge_index[1]], edge_index[0],
            dim=0, dim_size=X.shape[0])
        X = torch.nn.functional.leaky_relu(self.bn[idx](self_term + neighbour_term))
    return self.finish(X)
def forward(self, data):
    """Encode graphs with a joint graph/node convolution; one output row per graph.

    Graph-level rows (the mean of the last three feature columns weighted by
    column ``-5``, plus ``self.scatter_norm`` statistics) are encoded and
    stacked above the encoded node rows. Graph ``g`` is therefore row ``g``.
    The stack goes through ``self.TConv`` on the edges from
    ``self.return_edge_index(batch)``. The rows of graph ids present in
    ``batch`` are decoded.
    """
    x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
    node_feats = x.float()

    weight = node_feats[:, -5].view(-1, 1)
    centre = (scatter_sum(node_feats[:, -3:] * weight, batch, dim=0)
              / scatter_sum(weight, batch, dim=0))
    summary = torch.cat([centre, self.scatter_norm(node_feats, batch)], dim=1)

    encoded_nodes = self.x_encoder(node_feats)
    summary = self.CoC_encoder(summary)

    stacked = torch.cat([summary, encoded_nodes], dim=0)
    stacked = self.TConv(stacked, self.return_edge_index(batch))
    return self.decoder(stacked[batch.unique()])
def forward(self, data):
    """kNN-graph encoder with residual blocks and a four-way pooled readout.

    The pipeline:
      1. Build a ``self.k``-nearest-neighbour graph per graph of the batch
         on columns ``self.pos_idx`` (``data.edge_index`` is ignored).
      2. GGconv1 -> nn1 -> residual blocks 1-2 -> nn2 -> GGconv2 ->
         residual blocks 3-4 -> nn3, with ReLU after each conv and nn layer.
      3. Pool nodes per graph by max, min, sum and mean.
      4. Map the concatenation through ``nncat`` and ReLU.
    """
    relu = self.relu
    x, edge_index, batch = data.x, data.edge_index, data.batch
    edge_index = knn_graph(x=x[:, self.pos_idx], k=self.k, batch=batch).to(self.device)

    x = relu(self.GGconv1(x, edge_index))
    x = relu(self.nn1(x))
    x = x + self.resblock1(x)
    x = x + self.resblock2(x)
    x = relu(self.nn2(x))

    x = relu(self.GGconv2(x, edge_index))
    x = x + self.resblock3(x)
    x = x + self.resblock4(x)
    x = relu(self.nn3(x))

    pooled = torch.cat((
        scatter_max(x, batch, dim=0)[0],
        scatter_min(x, batch, dim=0)[0],
        scatter_sum(x, batch, dim=0),
        scatter_mean(x, batch, dim=0),
    ), dim=1)
    return relu(self.nncat(pooled))
def test_scatter_batch_idx(self):
    """Per-graph sums of a stacked QTensor agree across three methods.

    The methods are ``scatter_sum`` on the stack, ``scatter_sum`` per
    component, and ``global_add_pool``.
    """
    num_graphs, num_nodes, feat = 128, 2048, 64
    graph_idx = torch.randint(low=0, high=num_graphs, size=(num_nodes, ), dtype=torch.int64)
    q = QTensor(*torch.randn(4, num_nodes, feat))

    stacked = q.stack(dim=1)
    assert stacked.size() == torch.Size([num_nodes, 4, feat])

    pooled = scatter_sum(src=stacked, index=graph_idx, dim=0)
    assert torch.allclose(pooled, global_add_pool(stacked, batch=graph_idx))

    per_part = pooled.permute(1, 0, 2)
    q_from_stack = QTensor(*per_part)

    components = (q.r, q.i, q.j, q.k)
    by_scatter = [scatter_sum(part, graph_idx, dim=0) for part in components]
    q_from_scatter = QTensor(*by_scatter)
    assert q_from_stack == q_from_scatter
    for expected, actual in zip(per_part, by_scatter):
        assert torch.allclose(expected, actual)

    by_pool = [global_add_pool(part, graph_idx) for part in components]
    assert q_from_stack == q_from_scatter == QTensor(*by_pool)
def get_metrics(model, test_loader, k=64):
    """Average layer-wise smoothing metrics of ``model`` over ``test_loader``.

    For each batch the model's layers are replayed manually. After
    intermediate layer ``jdx`` the following are accumulated into slot
    ``jdx``:
      - ``sigmoid(model.norm[jdx].s1)``;
      - the mean of ``batched_MAD``;
      - ``batched_agg``;
      - the mean of ``rayleigh_quotient`` (using the batch's
        ``eig_max``/``eig_min``).
    Each accumulator is finally divided by the number of batches. Runs on
    CUDA in eval mode, without gradient tracking.

    Args:
        model: exposes ``start``, ``intermediate``, ``res`` and ``norm``.
        test_loader: iterable of batches with ``x``, ``edge_index``,
            ``edge_weight``, ``batch``, ``eig_max`` and ``eig_min``.
        k: accumulator length; must be at least the number of intermediate
            layers.

    Returns:
        ``(model_mad, model_agg, model_ray, model_s)``, each a one-element
        list holding a length-``k`` tensor of per-layer averages.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model_mad, model_agg, model_ray, model_s = [], [], [], []
    MAD, Agg, Ray, S = torch.zeros(k), torch.zeros(k), torch.zeros(k), torch.zeros(k)

    model.eval()
    num_batches = 0
    # Metrics only; skip building autograd graphs for every batch.
    with torch.no_grad():
        for data in test_loader:
            num_batches += 1
            X = data.x.cuda()
            edge_index, edge_weight = data.edge_index.cuda(), data.edge_weight.cuda()
            batch = data.batch.cuda()

            X = model.start(X)
            for jdx, m in enumerate(model.intermediate):
                # dim_size keeps one row per node even if the last nodes receive no edges.
                Update = m[0](X) + torch_scatter.scatter_sum(
                    edge_weight[:, None] * m[1](X)[edge_index[1]], edge_index[0],
                    dim=0, dim_size=X.shape[0])
                X = X + Update if model.res else Update
                # Layer index ``jdx``; the batch index was used here before,
                # which applied the wrong norm layer (or raised IndexError).
                X = torch.nn.functional.leaky_relu(
                    model.norm[jdx](X, edge_index, edge_weight, batch))

                S[jdx] += torch.sigmoid(model.norm[jdx].s1).item()
                MAD[jdx] += batched_MAD(X, edge_index, edge_weight).mean().item()
                Agg[jdx] += batched_agg(X, edge_index, edge_weight, batch).item()
                Ray[jdx] += rayleigh_quotient(
                    X, edge_index, edge_weight, batch,
                    data.eig_max.cuda(), data.eig_min.cuda()).mean().item()

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")

    model_mad.append(MAD / num_batches)
    model_agg.append(Agg / num_batches)
    model_ray.append(Ray / num_batches)
    model_s.append(S / num_batches)
    return (model_mad, model_agg, model_ray, model_s)
def scatter_mean(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-group mean of ``src`` along ``dim``, plus the raw group sizes.

    Sums are accumulated into ``out`` when it is given, so its existing
    values are included. Empty groups are divided by 1 instead of 0.
    Integer inputs use floor division; a true division result cannot be
    written back into an integer tensor, so ``div_`` used to raise for them.

    Returns:
        ``(mean, count)``. ``count`` is the number of elements per group,
        before clamping, so empty groups report 0.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        # ``index`` has fewer dims than ``src``; count along its last one.
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count_ret = count.clone()
    count.clamp_(1)
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.div_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out, count_ret