def gat_layer(self, input, adj, genPath=False, eluF=True):
    """Sparse graph-attention aggregation over the edges of `adj`.

    Args:
        input: node feature matrix; its first dimension is taken as N.
        adj: torch sparse tensor; `adj._indices()` supplies the [2, E] edge list.
        genPath: falsy to skip export; otherwise a directory in which the
            per-edge attention (softmax over each source node's edges) is saved
            as `attmat_<self.layerN>.npz`.
        eluF: apply ELU to the output (only when `self.concat` is also true).

    Returns:
        N x out tensor of attention-weighted neighbor features.
    """
    N = input.size()[0]
    # Allocate helper tensors on the input's device instead of hard-coding CUDA.
    device = input.device
    eps = torch.tensor([9e-15], device=device)
    edge = adj._indices()
    h = torch.mm(input, self.W)
    h = h + self.bias  # h: N x out

    # Shared attention mechanism: score each edge from [h_src || h_dst].
    edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()  # 2*D x E
    edge_att = self.a.mm(edge_h).squeeze()
    edge_e_a = self.leakyrelu(edge_att)  # E: raw attention score per edge

    if genPath:
        with torch.no_grad():
            src = edge[0, :]
            # Numerically stable softmax over the edges leaving each source node.
            p_a_e = edge_e_a - scatter_max(edge_e_a, src, dim=0, dim_size=N)[0][src]
            p_a_e = p_a_e.exp()
            p_a_e = p_a_e / (scatter_add(p_a_e, src, dim=0, dim_size=N)[src] + eps)
            scisp = convert.to_scipy_sparse_matrix(edge, p_a_e, N)
            scipy.sparse.save_npz(os.path.join(genPath, 'attmat_{:s}.npz'.format(self.layerN)), scisp)

    # Global max shift keeps exp() from overflowing; it cancels in the division below.
    edge_e = torch.exp(edge_e_a - torch.max(edge_e_a))  # E
    e_rowsum = spmm(edge, edge_e, N, torch.ones(size=(N, 1), device=device))  # N x 1
    # Dropout on attention coefficients (original note: improved 82.4 -> 83.8).
    edge_e = self.dropout(edge_e)
    h_prime = spmm(edge, edge_e, N, h)
    h_prime = h_prime.div(e_rowsum + eps)  # N x out
    if self.concat and eluF:
        return F.elu(h_prime)
    return h_prime
def batched_spmm(nzt, adj, x, m=None, n=None):
    """Multi-head sparse-dense product via a block-diagonal sparse matrix.

    Each head's adjacency is shifted by (m * head, n * head) so that all heads
    form one block-diagonal sparse matrix, which multiplies `x` repeated once
    per head.

    Args:
        nzt: Tensor [..., num_edges, heads] -- non-zero values per head. The
            einops pattern below requires a leading batch dim `t`.
        adj: Tensor or list(Tensor) -- COO adjacency (one shared, or one per head).
        x: Tensor [..., num_nodes, channels] -- feature matrix (leading dim `t`
            required by the einops pattern below).
        m: int, optional row count hint.
        n: int, optional column count hint.

    Returns:
        [batch, heads * m, channels] stacked per batch item.
    """
    num_edges, heads = nzt.shape[-2:]
    num_nodes, channels = x.shape[-2:]
    x_ = repeat(x, 't n c -> t (h n) c', h=heads)
    nzt_ = rearrange(nzt, 't e h -> t (h e)')

    if isinstance(adj, Tensor):
        m = maybe_num_nodes(adj[0], m)
        n = max(num_nodes, maybe_num_nodes(adj[1], n))
        # The offset must live on the adjacency's device; a CPU [2, 1] tensor
        # cannot be added to a CUDA index tensor.
        offset = torch.tensor([[m], [n]], device=adj.device)
        adj_ = torch.cat([adj + offset * i for i in range(heads)], dim=1)
    else:  # adj is a list of per-head adjacency matrices
        assert heads == len(
            adj), "the number of heads and the number of adjacency matrices are not matched"
        m = max([maybe_num_nodes(a[0], m) for a in adj])
        n = max([maybe_num_nodes(a[1], n) for a in adj])
        offset = torch.tensor([[m], [n]], device=adj[0].device)
        adj_ = torch.cat([adj[i] + offset * i for i in range(heads)], dim=1)

    if len(x.shape) == 2:
        # NOTE(review): unreachable in practice -- the 't n c' einops pattern
        # above already rejects 2-D x. Kept for interface parity.
        out = spmm(adj_, nzt_, heads * m, heads * n, x_)
        return out.view(-1, m, channels)  # [heads, m, channels]
    _size = x_.shape[0]
    out = torch.stack([spmm(adj_, nzt_[i], heads * m, heads * n, x_[i]) for i in range(_size)])
    return out  # [batch, heads * num_nodes, channels]
def forward(self, phi_indices, phi_values, phi_inverse_indices, phi_inverse_values,
            feature_indices, feature_values, dropout):
    """
    Forward propagation pass.
    :param phi_indices: Sparse wavelet matrix index pairs.
    :param phi_values: Sparse wavelet matrix values.
    :param phi_inverse_indices: Inverse wavelet matrix index pairs.
    :param phi_inverse_values: Inverse wavelet matrix values.
    :param feature_indices: Feature matrix index pairs.
    :param feature_values: Feature matrix values.
    :param dropout: Dropout rate.
    :return dropout_features: Filtered feature matrix extracted.
    """
    size = self.ncount
    # Phi @ diag(filter), kept sparse.
    scaled_idx, scaled_val = spspmm(phi_indices, phi_values,
                                    self.diagonal_weight_indices,
                                    self.diagonal_weight_filter.view(-1),
                                    size, size, size)
    # (Phi @ diag(filter)) @ Phi_inverse: the learned wavelet kernel.
    kernel_idx, kernel_val = spspmm(scaled_idx, scaled_val,
                                    phi_inverse_indices, phi_inverse_values,
                                    size, size, size)
    # Project the sparse features, then apply the kernel.
    projected = spmm(feature_indices, feature_values, size, self.weight_matrix)
    convolved = spmm(kernel_idx, kernel_val, size, projected)
    activated = torch.nn.functional.relu(convolved)
    return torch.nn.functional.dropout(activated, training=self.training, p=dropout)
def forward(self, normalized_adjacency_matrix, features, dropout_rate, transform, density):
    """
    Doing a forward pass.
    :param normalized_adjacency_matrix: Normalized adjacency matrix.
    :param features: Feature matrix.
    :param dropout_rate: Dropout value.
    :param transform: Activation function application rule.
    :param density: Density structure of the feature matrix.
    :return localized_features: Convolved features.
    """
    if not density:
        projected = spmm(features["indices"], features["values"],
                         features["dimensions"][0], self.weight_matrix)
    else:
        projected = torch.mm(features, self.weight_matrix)
    projected = torch.nn.functional.dropout(projected, p=dropout_rate, training=self.training)
    if transform:
        projected = torch.nn.functional.relu(projected) + self.bias

    # Personalized-propagation iterations: blend neighbour aggregate with the
    # teleport term alpha * projected.
    propagated = projected
    for _ in range(self.iterations):
        aggregated = spmm(normalized_adjacency_matrix["indices"],
                          normalized_adjacency_matrix["values"],
                          propagated.shape[0], propagated)
        propagated = (1 - self.alpha) * aggregated + self.alpha * projected
    return propagated
def conv_f(x, stride, kernel_size, layer, transform, subspace, pad='reflection'):
    """2-D convolution whose weight and bias are generated from `subspace`.

    The kernel and bias are obtained by multiplying the sparse matrices
    `transform['weight']` / `transform['bias']` with `subspace` and reshaping
    them to `transform['w_shape']` / `transform['b_shape']`.

    Args:
        x: input passed to `F.pad` / `F.conv2d`.
        stride: convolution stride.
        kernel_size: kernel size; (kernel_size - 1) // 2 pixels are padded per side.
        layer: unused; kept for interface compatibility.
        transform: dict holding sparse 'weight'/'bias' plus 'w_num', 'b_num',
            'w_shape', 'b_shape'.
        subspace: dense matrix the sparse transforms are applied to.
        pad: padding mode. Accepts 'reflection' (default), 'replication',
            'zero', or any `F.pad` mode name ('reflect', 'replicate',
            'constant', 'circular'). Previously this argument was ignored and
            reflection padding was always used.

    Returns:
        Convolution output from `F.conv2d`.
    """
    i, v = transform['weight']._indices(), transform['weight']._values()
    weight = torch_sparse.spmm(i, v, transform['w_num'], subspace).view(transform['w_shape'])
    i, v = transform['bias']._indices(), transform['bias']._values()
    bias = torch_sparse.spmm(i, v, transform['b_num'], subspace).view(transform['b_shape'])

    # Translate the descriptive names into F.pad modes; other values pass through.
    pad_mode = {'reflection': 'reflect', 'replication': 'replicate', 'zero': 'constant'}.get(pad, pad)
    to_pad = int((kernel_size - 1) / 2)
    x = F.pad(x, (to_pad, to_pad, to_pad, to_pad), mode=pad_mode)
    return F.conv2d(x, weight, bias, stride)
def forward(self, normalized_adjacency_matrix, features):
    """
    Doing a forward pass.
    :param normalized_adjacency_matrix: Normalized adjacency matrix.
    :param features: Feature matrix.
    :return base_features: Convolved features.
    """
    # Sparse extent along each index row: largest index + 1.
    extent = torch.max(features["indices"], dim=1)[0] + 1
    hidden = spmm(features["indices"], features["values"],
                  extent[0], extent[1], self.weight_matrix)
    hidden = torch.nn.functional.dropout(hidden + self.bias,
                                         p=self.dropout_rate,
                                         training=self.training)
    hidden = torch.nn.functional.relu(hidden)
    # Propagate (iterations - 1) times over the normalized adjacency.
    for _ in range(self.iterations - 1):
        num_rows = hidden.shape[0]
        hidden = spmm(normalized_adjacency_matrix["indices"],
                      normalized_adjacency_matrix["values"],
                      num_rows, num_rows, hidden)
    return hidden
def forward(self, x, edge_index, edge_attr=None):
    """Chebyshev graph convolution.

    Builds the off-diagonal weights -D^{-1/2} A D^{-1/2} (self-loops removed)
    and sums torch.mm(T_k, self.weight[k]) over the recurrence
    T_0 = x, T_1 = L x, T_k = 2 L T_{k-1} - T_{k-2}, then adds the bias.
    """
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    row, col = edge_index
    num_nodes = x.size(0)
    num_edges = row.size(0)
    K = self.weight.size(0)

    if edge_attr is None:
        edge_attr = x.new_ones((num_edges, ))
    assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

    # D^{-1/2}, with isolated nodes mapped to 0 instead of inf.
    inv_sqrt_deg = degree(row, num_nodes, dtype=x.dtype).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    lap = -inv_sqrt_deg[row] * edge_attr * inv_sqrt_deg[col]

    out = torch.mm(x, self.weight[0])
    if K > 1:
        prev, cur = x, spmm(edge_index, lap, num_nodes, x)
        out = out + torch.mm(cur, self.weight[1])
        for k in range(2, K):
            prev, cur = cur, 2 * spmm(edge_index, lap, num_nodes, cur) - prev
            out = out + torch.mm(cur, self.weight[k])

    if self.bias is not None:
        out = out + self.bias
    return out
def meancurvature(pos, faces):
    """Per-vertex curvature magnitude |M^{-1} S pos| from `laplacebeltrami_FEM_v2`.

    Args:
        pos: vertex positions; last dimension must be 3.
        faces: triangle vertex indices; last dimension must be 3.

    Returns:
        (row-wise L2 norm of M^{-1} S pos, stiff, mass) where `stiff` and `mass`
        are the (indices, values) pairs returned by `laplacebeltrami_FEM_v2`.
        M^{-1} is formed by taking reciprocals of the mass values on the same
        sparsity pattern -- presumably a diagonal (lumped) mass; TODO confirm.

    Raises:
        ValueError: if `pos` or `faces` does not have a trailing size of 3.
    """
    checks = ((pos, "Vertices positions must have shape [n,3]"),
              (faces, "Face indices must have shape [m,3]"))
    for arr, message in checks:
        if arr.shape[-1] != 3:
            raise ValueError(message)

    num_verts = pos.shape[0]
    stiff, mass = laplacebeltrami_FEM_v2(pos, faces)
    mass_idx, mass_val = mass
    inv_mass = torch.reciprocal(mass_val)
    stiff_pos = tsparse.spmm(*stiff, num_verts, num_verts, pos)
    curvature_vec = tsparse.spmm(mass_idx, inv_mass, num_verts, num_verts, stiff_pos)
    return curvature_vec.norm(dim=-1, p=2), stiff, mass
def forward(self, x, edge_index, edge_weight=None):
    """Chebyshev graph convolution with a per-order `self.conv_out` transform.

    For each Chebyshev term T_k (T_0 = x, T_1 = L x, T_k = 2 L T_{k-1} - T_{k-2},
    with L given by -D^{-1/2} A D^{-1/2} after removing self-loops), computes
    torch.mm(self.conv_out(T_k, k), self.weight[k]) and sums all terms, then
    adds the bias.
    """
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    row, col = edge_index
    num_nodes, num_edges, K = x.size(0), row.size(0), self.K

    if edge_weight is None:
        edge_weight = x.new_ones((num_edges, ))
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    inv_sqrt_deg = degree(row, num_nodes, dtype=x.dtype).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    lap = -inv_sqrt_deg[row] * edge_weight * inv_sqrt_deg[col]

    def project(term, k):
        return torch.mm(self.conv_out(term, k), self.weight[k])

    terms = [project(x, 0)]
    if K > 1:
        prev, cur = x, spmm(edge_index, lap, num_nodes, x)
        terms.append(project(cur, 1))
        for k in range(2, K):
            prev, cur = cur, 2 * spmm(edge_index, lap, num_nodes, cur) - prev
            terms.append(project(cur, k))

    out = torch.stack(terms, dim=0).sum(dim=0)
    if self.bias is not None:
        out = out + self.bias
    return out
def forward(self, personalized_page_rank_matrix, features, dropout_rate, transform, density):
    """
    Doing a forward pass.
    :param personalized_page_rank_matrix: Dense personalized pagerank matrix.
    :param features: Feature matrix.
    :param dropout_rate: Dropout value.
    :param transform: Activation function application rule.
    :param density: Density structure of the feature matrix.
    :return localized_features: Convolved features.
    """
    if not density:
        projected = spmm(features["indices"], features["values"],
                         features["dimensions"][0], self.weight_matrix)
    else:
        projected = torch.mm(features, self.weight_matrix)
    projected = torch.nn.functional.dropout(projected, p=dropout_rate, training=self.training)
    if transform:
        projected = torch.nn.functional.relu(projected)
    # Dense PPR propagation followed by the bias.
    return torch.mm(personalized_page_rank_matrix, projected) + self.bias
def forward(self, x, edge_index, edge_attr=None):
    """Graph convolution D^{-1/2} (A + fill*I) D^{-1/2} X W (+ bias).

    `fill` is 1, or 2 when `self.improved` is set. A 1-D `x` is treated as a
    single feature column.
    """
    if x.dim() == 1:
        x = x.unsqueeze(-1)
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    if edge_attr is None:
        edge_attr = x.new_ones((edge_index.size(1), ))
    assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

    num_nodes = x.size(0)
    # Re-insert self-loops with weight 1 (or 2 for the "improved" variant).
    edge_index = add_self_loops(edge_index, num_nodes)
    loop_fill = 2 if self.improved else 1
    edge_attr = torch.cat([edge_attr, x.new_full((num_nodes, ), loop_fill)], dim=0)

    # Symmetric normalization; zero-degree rows map to 0 instead of inf.
    row, col = edge_index
    inv_sqrt_deg = scatter_add(edge_attr, row, dim=0, dim_size=num_nodes).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    edge_attr = inv_sqrt_deg[row] * edge_attr * inv_sqrt_deg[col]

    projected = torch.mm(x, self.weight)
    out = spmm(edge_index, edge_attr, projected.size(0), projected)
    return out if self.bias is None else out + self.bias
def forward(self, x, edge_index, edge_attr): N, dim = x.shape # x = self.dropout(x) # adj_mat_ind, adj_mat_val = add_self_loops(edge_index, num_nodes=N)[0], edge_attr.squeeze() adj_mat_ind = add_remaining_self_loops(edge_index, num_nodes=N)[0] adj_mat_val = torch.ones(adj_mat_ind.shape[1]).to(x.device) h = torch.mm(x, self.weight) h = F.dropout(h, p=self.dropout, training=self.training) for _ in range(self.nhop - 1): adj_mat_ind, adj_mat_val = spspmm(adj_mat_ind, adj_mat_val, adj_mat_ind, adj_mat_val, N, N, N, True) adj_mat_ind, adj_mat_val = self.attention(h, adj_mat_ind, adj_mat_val) # MATRIX_MUL # laplacian matrix normalization adj_mat_val = self.normalization(adj_mat_ind, adj_mat_val, N) val_h = h # N, dim = val_h.shape # MATRIX_MUL # val_h = spmm(adj_mat_ind, F.dropout(adj_mat_val, p=self.node_dropout, training=self.training), N, N, val_h) val_h = spmm(adj_mat_ind, adj_mat_val, N, N, val_h) val_h[val_h != val_h] = 0 val_h = val_h + self.bias val_h = self.adaptive_enc(val_h) val_h = F.dropout(val_h, p=self.dropout, training=self.training) # val_h = self.activation(val_h) return val_h
def loop_sparse_attention_centrality(self, attention, idx):
    """Score every point by the total attention it receives, then take the top-k.

    For each batch item and group, the (point, neighbor) attention entries form
    a sparse npoints x npoints matrix A (row = point, column = idx[b][point, k]).
    Duplicate entries are summed by `coalesce`; A^T @ 1 then gives, per point,
    the sum of attention directed at it.

    Args:
        attention: tensor unpacked as (batch, groups, npoints, neighbors).
        idx: neighbor indices; idx[b] flattens to npoints * neighbors entries.

    Returns:
        (values, indices) from `topk(k=neighbors)` over the point axis, each of
        shape (batch, groups, 1, neighbors).
    """
    batch, groups, npoints, neighbors = attention.size()
    # Follow the input's device instead of hard-coding 'cuda'.
    device = attention.device
    # Source point id for every flattened (point, neighbor) entry: 0,0,..,1,1,..
    idx_tag = torch.arange(npoints, device=device).repeat_interleave(neighbors).unsqueeze(0)
    ones = torch.ones(npoints, 1, device=device)

    score = []
    for i in range(batch):
        index = torch.cat([idx_tag, idx[i].flatten().unsqueeze(0)], dim=0)
        score_group = []
        for j in range(groups):
            index_s, value_s = coalesce(index, attention[i][j].flatten(), npoints, npoints)
            index_t, value_t = transpose(index_s, value_s, npoints, npoints)
            score_group.append(spmm(index_t, value_t, npoints, npoints, ones))  # npoints x 1
        score.append(torch.cat(score_group, dim=1).unsqueeze(0))  # 1 x npoints x groups

    final_score = torch.cat(score, dim=0)  # batch x npoints x groups
    final_score = final_score.unsqueeze(3).permute(0, 2, 3, 1)  # B, G, 1, N
    top_values, top_indices = final_score.topk(k=neighbors, dim=3)  # B, G, 1, K
    return top_values, top_indices
def forward(self, phi_indices, phi_values, phi_inverse_indices, phi_inverse_values, features):
    """
    Forward propagation pass.
    :param phi_indices: Sparse wavelet matrix index pairs.
    :param phi_values: Sparse wavelet matrix values.
    :param phi_inverse_indices: Inverse wavelet matrix index pairs.
    :param phi_inverse_values: Inverse wavelet matrix values.
    :param features: Feature matrix.
    :return localized_features: Filtered feature matrix extracted.
    """
    size = self.ncount
    diag_filter = self.diagonal_weight_filter.view(-1)
    # Kernel = Phi @ diag(filter) @ Phi_inverse, built sparsely.
    scaled_idx, scaled_val = spspmm(phi_indices, phi_values,
                                    self.diagonal_weight_indices, diag_filter,
                                    size, size, size)
    kernel_idx, kernel_val = spspmm(scaled_idx, scaled_val,
                                    phi_inverse_indices, phi_inverse_values,
                                    size, size, size)
    projected = torch.mm(features, self.weight_matrix)
    return spmm(kernel_idx, kernel_val, size, projected)
def do_conv(self, x):
    """Sum `local_power`-hop propagations of `x` over every ego network.

    For each ego network, `x` is multiplied `local_power` times by that ego's
    sparse operator (`ego_norm_ind`, `ego_norm_val`, size numNodes x numNodes).
    The per-ego results are summed in `self.egoNets` order and scaled
    element-wise by `self.norm_degrees`.

    `local_power` is read from the enclosing (module) scope.
    """
    def propagate(ego):
        # Transfer the sparse operator once per ego rather than once per hop.
        ind = ego.ego_norm_ind.to(self.device)
        val = ego.ego_norm_val.to(self.device)
        out = torch_sparse.spmm(ind, val, self.numNodes, self.numNodes, x)
        for _ in range(local_power - 1):
            out = torch_sparse.spmm(ind, val, self.numNodes, self.numNodes, out)
        return out

    output = propagate(self.egoNets[0])
    for i, ego in enumerate(self.egoNets):
        if i == 0:
            continue
        output = output + propagate(ego)

    output = torch.mul(output, self.norm_degrees)
    torch.cuda.empty_cache()
    return output
def reference(self, column_index, val, num_nodes):
    '''
    Compute the reference neighbor-aggregation (SpMM) result on the CPU and
    store it in `self.result_ref`.

    NOTE(review): `column_index` is passed directly as the spmm index, which
    presumably must be a [2, nnz] (row, col) array -- confirm with the caller.
    '''
    print("# Compute reference on CPU")
    index = torch.tensor(column_index, dtype=torch.int64)
    values = torch.FloatTensor(val)
    self.result_ref = spmm(index, values, num_nodes, num_nodes, self.X)
def forward(self, edge_index, edge_attr, N):
    """Scale each edge (i, j) by 1 / s[j], where s = A @ 1 is the weighted row sum.

    Args:
        edge_index: COO index with rows edge_index[0] and columns edge_index[1].
        edge_attr: per-edge weights.
        N: number of nodes.

    Returns:
        Re-weighted edge values. Nodes whose row sum is zero (e.g. sinks of a
        directed graph) contribute weight 0 instead of inf.
    """
    ones = torch.ones(N, 1, device=edge_attr.device)
    row_sum = spmm(edge_index, edge_attr, N, N, ones).view(-1)
    inv_row_sum = 1. / row_sum
    # 1/0 would otherwise inject inf into every edge pointing at such a node.
    inv_row_sum[torch.isinf(inv_row_sum)] = 0
    return inv_row_sum[edge_index[1]] * edge_attr
def forward(self, x):
    """Batched Chebyshev convolution over the fixed graph stored on the module.

    `x` is indexed as (batch, num_nodes, in_channels). The batch is folded into
    the channel axis so each propagation is one 2-D sparse product. Returns
    sum_k matmul(T_k, self.weight[k]) (+ bias), with T_0 = x, T_1 = L x,
    T_k = 2 L T_{k-1} - T_{k-2}.
    """
    K, lap, edge_index, num_nodes = self.K, self.lap, self.edge_index, self.num_nodes
    assert (num_nodes == x.shape[1])

    def propagate(t):
        # (batch, nodes, ch) -> (nodes, batch*ch) -> L @ . -> back to (batch, nodes, ch).
        flat = t.permute(1, 0, 2).reshape(num_nodes, -1)
        prod = spmm(edge_index, lap, num_nodes, flat)
        return prod.reshape(num_nodes, -1, self.in_channels).permute(1, 0, 2)

    out = torch.matmul(x, self.weight[0])
    if K > 1:
        prev, cur = x, propagate(x)
        out = out + torch.matmul(cur, self.weight[1])
        for k in range(2, K):
            prev, cur = cur, 2 * propagate(cur) - prev
            out = out + torch.matmul(cur, self.weight[k])

    if self.bias is not None:
        out = out + self.bias
    return out
def forward(self, normalized_adjacency_matrix, features):
    """
    Doing a forward pass.
    :param normalized_adjacency_matrix: Normalized adjacency matrix.
    :param features: Feature matrix.
    :return base_features: Convolved features.
    """
    projected = spmm(features["indices"], features["values"],
                     features["dimensions"][0], self.weight_matrix)
    dropped = torch.nn.functional.dropout(projected, p=self.dropout_rate,
                                          training=self.training)
    hidden = torch.nn.functional.relu(dropped) + self.bias
    # Apply the normalized adjacency `self.iterations` times.
    for _ in range(self.iterations):
        hidden = spmm(normalized_adjacency_matrix["indices"],
                      normalized_adjacency_matrix["values"],
                      hidden.shape[0], hidden)
    return hidden
def bn_f(x, layer, bn_module, subspace):
    """Batch-normalize `x` with affine parameters generated from `subspace`.

    The BN weight and bias are not taken from `bn_module`. They are computed by
    multiplying the sparse matrices `layer['weight']` / `layer['bias']` with
    `subspace` and reshaping the results to `layer['w_shape']` /
    `layer['b_shape']`. Only the running statistics come from `bn_module`.

    Returns:
        The output of `F.batch_norm`; the result of `bn_module(x)` is discarded.
    """
    run_mean = bn_module.running_mean
    run_var = bn_module.running_var
    # Earlier dense / torch.sparse variants, kept for reference:
    #weight = torch.matmul(layer['weight'], subspace).view(layer['w_shape'])
    #bias = torch.matmul(layer['bias'], subspace).view(layer['b_shape'])
    # weight = torch.sparse.mm(layer['weight'], subspace).view(layer['w_shape'])
    # bias = torch.sparse.mm(layer['bias'], subspace).view(layer['b_shape'])
    i, v = layer['weight']._indices(), layer['weight']._values()
    weight = torch_sparse.spmm(i, v, layer['w_num'], subspace).view(layer['w_shape'])
    i, v = layer['bias']._indices(), layer['bias']._values()
    bias = torch_sparse.spmm(i, v, layer['b_num'], subspace).view(layer['b_shape'])
    # training=True: normalize with batch statistics regardless of
    # bn_module.training, and update run_mean / run_var in place.
    y = F.batch_norm(x, run_mean, run_var, weight, bias, training=True)
    # NOTE(review): output unused; presumably called for its side effects on
    # bn_module. If bn_module is in training mode this updates the same running
    # statistics a second time per call -- confirm this is intended.
    dummy_y = bn_module(x)
    return y
def test_spmm(dtype, device):
    # Sparse A = [[1, 0, 2], [0, 4, 0], [1, 3, 0]] in COO form (row; col).
    index = torch.tensor([[0, 0, 1, 2, 2],
                          [0, 2, 1, 0, 1]], device=device)
    value = tensor([1, 2, 4, 1, 3], dtype, device)
    x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
    expected = [[7, 16], [8, 20], [7, 19]]
    assert spmm(index, value, 3, 3, x).tolist() == expected
def forward(self, feature, adj):
    """Graph convolution adj @ (feature @ self.weight) (+ bias).

    `adj` is a torch sparse tensor; its COO indices/values feed torch_sparse's
    spmm with `adj.size(0)` rows.
    """
    # Earlier variant (torch.spmm on sparse feature / adj):
    # support = torch.spmm(feature, self.weight)
    # output = torch.spmm(adj, support)
    support = torch.mm(feature, self.weight)
    output = spmm(adj._indices(), adj._values(), adj.size(0), support)
    return output if self.bias is None else output + self.bias
def forward(self, edge_index, edge_attr, N):
    """Symmetric normalization: edge (i, j) -> s[i]^{-1/2} * w_ij * s[j]^{-1/2}.

    s = A @ 1 is the weighted row sum of each node. Nodes with zero row sum get
    a factor of 0 instead of inf, matching the degree handling used by the
    Chebyshev / GCN layers.
    """
    ones = torch.ones(N, 1, device=edge_attr.device)
    inv_sqrt = spmm(edge_index, edge_attr, N, N, ones).view(-1).pow(-0.5)
    # 0 ** -0.5 == inf would otherwise poison every edge touching such a node.
    inv_sqrt[inv_sqrt == float('inf')] = 0
    return inv_sqrt[edge_index[0]] * edge_attr * inv_sqrt[edge_index[1]]
def test_spmm():
    # Sparse A = [[1, 0, 2], [0, 4, 0], [1, 3, 0]] in COO form (row; col).
    index = torch.tensor([[0, 0, 1, 2, 2],
                          [0, 2, 1, 0, 1]])
    value = torch.tensor([1, 2, 4, 1, 3])
    matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
    expected = [[7, 16], [8, 20], [7, 19]]
    assert spmm(index, value, 3, matrix).tolist() == expected
def forward(self, edge_index, edge_attr, N):
    """Row-wise softmax of edge scores: exp(w_ij) / sum_k exp(w_ik).

    NOTE(review): exp() is applied without subtracting a (per-row) max, so very
    large edge_attr values can overflow to inf.
    """
    weights = torch.exp(edge_attr)
    ones = torch.ones(N, 1, device=edge_attr.device)
    inv_row_sum = 1. / spmm(edge_index, weights, N, N, ones).view(-1)
    return inv_row_sum[edge_index[0]] * weights
def mgunpool(x, index, values, origsize, newsize):
    """Unpool pooled features back to the original node set.

    The COO (index, values) are coalesced into an origsize x newsize sparse
    matrix and applied to `x` (newsize rows), yielding origsize rows. Per the
    original note, this stored matrix is P^T where P is the (pooled x unpooled)
    pooling operator.

    Args:
        x: pooled features.
        index, values: COO entries of the transfer matrix.
        origsize: unpooled (original) size.
        newsize: pooled size.
    """
    p_index, p_values = torch_sparse.coalesce(index, values, m=origsize, n=newsize)
    return torch_sparse.spmm(p_index, p_values, m=origsize, n=newsize, matrix=x)
def forward(self, features, meshes):
    """Unroll per-edge features into `self.unroll_target` slots per mesh.

    Args:
        features: tensor unpacked as (batch_size, nf, edges).
        meshes: one mesh per batch item providing `get_groups()`,
            `get_occurrences()` and `unroll_gemm()`.

    Returns:
        Tensor (batch_size, nf, self.unroll_target). For batch item b, slot t
        is the sum over group entries (e, t) of features[b, :, e] times the
        group value divided by the (padded) occurrence count of t.

    Side effects:
        Calls `mesh.unroll_gemm()` on every mesh.
    """
    batch_size, nf, edges = features.shape
    device = features.device

    groups = [mesh.get_groups() for mesh in meshes]
    og_occu = [mesh.get_occurrences() for mesh in meshes]
    occurrences = torch.cat([self.pad_occurrences(o) for o in og_occu], dim=0)
    # Broadcast each target's occurrence count across the edge axis.
    occurrences = occurrences.view(batch_size, 1, -1).expand(
        (batch_size, edges, self.unroll_target))

    groups = [self.sparse_pad_groups(g, edges) for g in groups]
    # Prefix each group's (edge, target) COO indices with its batch id -> (3, nnz).
    indices = torch.cat([
        torch.cat([
            torch.full((1, g.indices().shape[-1]), idx, dtype=torch.int64, device=device),
            g.indices(),
        ], dim=0)
        for idx, g in enumerate(groups)
    ], dim=1)
    values = torch.cat([g.values() for g in groups], dim=0)
    values = values / occurrences[indices[0, :], indices[1, :], indices[2, :]]

    # Always allocate: `self.result` is never assigned here, and a conditional
    # allocation left `result` unbound (UnboundLocalError) in the other case.
    result = torch.zeros((batch_size, features.shape[1], self.unroll_target), device=device)

    # Swap to (batch, target, edge) so each spmm yields (target x nf).
    b, row, col = indices
    indices = torch.stack([b, col, row], dim=0)
    transposed_features = features.transpose(1, 2)
    for b_idx in range(batch_size):
        mask = indices[0, :] == b_idx
        result[b_idx, :, :] = spmm(indices[1:, mask], values[mask], self.unroll_target,
                                   edges, transposed_features[b_idx, :, :]).T

    for mesh in meshes:
        mesh.unroll_gemm()
    return result
def forward(self, phi_indices, phi_values, phi_inverse_indices, phi_inverse_values,
            feature_indices, feature_values, dropout):
    """
    Forward propagation pass with sparse input features.
    :param phi_indices: Sparse wavelet matrix index pairs.
    :param phi_values: Sparse wavelet matrix values.
    :param phi_inverse_indices: Inverse wavelet matrix index pairs.
    :param phi_inverse_values: Inverse wavelet matrix values.
    :param feature_indices: Feature matrix index pairs.
    :param feature_values: Feature matrix values.
    :param dropout: Dropout rate.
    :return: ReLU + dropout of (Phi diag(filter) Phi_inverse) (X W).
    """
    size = self.ncount
    scaled_idx, scaled_val = spspmm(phi_indices, phi_values,
                                    self.diagonal_weight_indices,
                                    self.diagonal_weight_filter.view(-1),
                                    size, size, size)
    kernel_idx, kernel_val = spspmm(scaled_idx, scaled_val,
                                    phi_inverse_indices, phi_inverse_values,
                                    size, size, size)
    projected = spmm(feature_indices, feature_values, size, self.weight_matrix)
    convolved = spmm(kernel_idx, kernel_val, size, projected)
    activated = torch.nn.functional.relu(convolved)
    return torch.nn.functional.dropout(activated, training=self.training, p=dropout)
def get_feat(self, graph_list):
    """Encode `graph_list` with `self.mlp` after building dense input features.

    In 'dense' mode `self.get_fp` already returns dense features. Otherwise it
    returns sparse (indices, values), which are pushed through
    `self.input_linear` via a sparse-dense product.

    Returns:
        (self.mlp(features), None)
    """
    if self.feat_mode != 'dense':
        sp_indices, vals = self.get_fp(graph_list)
        linear = self.input_linear
        dense_feat = spmm(sp_indices, vals, len(graph_list),
                          linear.weight.transpose(0, 1)) + linear.bias
    else:
        dense_feat = self.get_fp(graph_list)
    return self.mlp(dense_feat), None
def _spmm(self, inp, params):
    """Apply a sparse (size[0] x size[1]) matrix along the last axis of `inp`.

    Args:
        inp: tensor whose last dimension has size[1] entries.
        params: (indices, values, size) describing the sparse matrix.

    Returns:
        Tensor with the same leading shape as `inp` and last dimension size[0].
        Previously the result was viewed back to `inp`'s own shape, which was
        only correct for square matrices.
    """
    ii, vv, size = params
    lead_shape = inp.shape[:-1]
    # reshape (not view) also accepts non-contiguous inputs.
    inp_flat_T = inp.reshape(-1, inp.size(-1)).t()
    out_flat = torch_sparse.spmm(ii, vv, m=size[0], n=size[1], matrix=inp_flat_T).t()
    return out_flat.reshape(*lead_shape, size[0])