Example #1
    def gat_layer(self, input, adj, genPath=False, eluF=True):
        N = input.size()[0]
        edge = adj._indices()
        h = torch.mm(input, self.W)
        h = h+self.bias                # h: N x out

        # Self-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()     # edge_h: 2*D x E
        edge_att = self.a.mm(edge_h).squeeze()
        edge_e_a = self.leakyrelu(edge_att)     # edge_e_a: E   attention score for each edge
        if genPath:
            with torch.no_grad():
                edge_weight = edge_e_a
                p_a_e = edge_weight - scatter_max(edge_weight, edge[0,:], dim=0, dim_size=N)[0][edge[0,:]]
                p_a_e = p_a_e.exp()
                p_a_e = p_a_e / (scatter_add(p_a_e, edge[0,:], dim=0, dim_size=N)[edge[0,:]]\
                                    +torch.Tensor([9e-15]).cuda())
                
                scisp = convert.to_scipy_sparse_matrix(edge, p_a_e, N)
                scipy.sparse.save_npz(os.path.join(genPath, 'attmat_{:s}.npz'.format(self.layerN)), scisp)

        edge_e = torch.exp(edge_e_a - torch.max(edge_e_a))                  # edge_e: E
        e_rowsum = spmm(edge, edge_e, N, torch.ones(size=(N,1)).cuda())     # e_rowsum: N x 1
        edge_e = self.dropout(edge_e)       # adding dropout improves accuracy from 82.4 to 83.8
        # edge_e: E
        
        h_prime = spmm(edge, edge_e, N, h)
        h_prime = h_prime.div(e_rowsum+torch.Tensor([9e-15]).cuda())        # h_prime: N x out
        
        if self.concat and eluF:
            return F.elu(h_prime)
        else:
            return h_prime
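A minimal sketch of the same edge-softmax aggregation pattern in plain PyTorch (the toy graph, names, and values below are hypothetical, not from the original project): exponentiate the edge scores with the maximum subtracted for stability, obtain per-row sums via a sparse product with a ones vector, then divide.

import torch

# Toy graph: 3 nodes, 4 directed edges in COO form.
N = 3
edge = torch.tensor([[0, 0, 1, 2],
                     [1, 2, 2, 0]])
score = torch.tensor([0.5, -1.0, 2.0, 0.3])    # one attention logit per edge
h = torch.randn(N, 8)                          # node features

edge_e = torch.exp(score - score.max())                  # numerically stable exp
att = torch.sparse_coo_tensor(edge, edge_e, (N, N))
e_rowsum = torch.sparse.mm(att, torch.ones(N, 1))        # N x 1 row sums
h_prime = torch.sparse.mm(att, h) / (e_rowsum + 1e-16)   # softmax-normalized aggregation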
Example #2
File: linalg.py  Project: TYSSSY/Apb-gcn
def batched_spmm(nzt, adj, x, m=None, n=None):
    """
    Args:
        nzt: Tensor [num_edges, heads]    -- non-zero tensor
        adj: Tensor or list(Tensor)       -- adjacency matrix (COO)
        x:   Tensor [num_nodes, channels] -- feature matrix
        m:   int
        n:   int
    """
    num_edges, heads = nzt.shape[-2:]
    num_nodes, channels = x.shape[-2:]
    # preparation of data
    # x_ = torch.cat(heads * [x])  # duplicate x for heads times
    # nzt_ = nzt.view(-1)
    x_ = repeat(x, 't n c -> t (h n) c', h=heads)
    nzt_ = rearrange(nzt, 't e h -> t (h e)')
    if isinstance(adj, Tensor):
        m = maybe_num_nodes(adj[0], m)
        n = max(num_nodes, maybe_num_nodes(adj[1], n))
        offset = torch.tensor([[m], [n]])
        adj_ = torch.cat([adj + offset * i for i in range(heads)], dim=1)
    else:  # adj is list of adjacency matrices
        assert heads == len(adj), \
            "number of heads must match the number of adjacency matrices"
        m = max([maybe_num_nodes(adj_[0], m) for adj_ in adj])
        n = max([maybe_num_nodes(adj_[1], n) for adj_ in adj])
        offset = torch.tensor([[m], [n]])
        adj_ = torch.cat([adj[i] + offset * i for i in range(heads)], dim=1)
    if len(x.shape) == 2:
        out = spmm(adj_, nzt_, heads * m, heads * n, x_)
        return out.view(-1, m, channels)  # [heads, m, channels]
    else:
        _size = x_.shape[0]
        out = torch.stack([spmm(adj_, nzt_[i], heads * m, heads * n, x_[i]) for i in range(_size)])
        return out  # [batch, heads * num_nodes, channels]
Example #3
    def forward(self, phi_indices, phi_values, phi_inverse_indices,
                phi_inverse_values, feature_indices, feature_values, dropout):
        """
        Forward propagation pass.
        :param phi_indices: Sparse wavelet matrix index pairs.
        :param phi_values: Sparse wavelet matrix values.
        :param phi_inverse_indices: Inverse wavelet matrix index pairs.
        :param phi_inverse_values: Inverse wavelet matrix values.
        :param feature_indices: Feature matrix index pairs.
        :param feature_values: Feature matrix values.
        :param dropout: Dropout rate.
        :return dropout_features: Filtered feature matrix extracted.
        """
        rescaled_phi_indices, rescaled_phi_values = spspmm(
            phi_indices, phi_values, self.diagonal_weight_indices,
            self.diagonal_weight_filter.view(-1), self.ncount, self.ncount,
            self.ncount)

        phi_product_indices, phi_product_values = spspmm(
            rescaled_phi_indices, rescaled_phi_values, phi_inverse_indices,
            phi_inverse_values, self.ncount, self.ncount, self.ncount)

        filtered_features = spmm(feature_indices, feature_values, self.ncount,
                                 self.weight_matrix)

        localized_features = spmm(phi_product_indices, phi_product_values,
                                  self.ncount, filtered_features)

        dropout_features = torch.nn.functional.dropout(
            torch.nn.functional.relu(localized_features),
            training=self.training,
            p=dropout)
        return dropout_features
Example #4
 def forward(self, normalized_adjacency_matrix, features, dropout_rate,
             transform, density):
     """
     Doing a forward pass.
     :param normalized_adjacency_matrix: Normalized adjacency matrix.
     :param features: Feature matrix.
     :param dropout_rate: Dropout value.
     :param transform: Activation function application rule.
      :param density: Density structure of the feature matrix.
     :return localized_features: Convolved features.
     """
     if density:
         base_features = torch.mm(features, self.weight_matrix)
     else:
         base_features = spmm(features["indices"], features["values"],
                              features["dimensions"][0], self.weight_matrix)
     base_features = torch.nn.functional.dropout(base_features,
                                                 p=dropout_rate,
                                                 training=self.training)
     if transform:
         base_features = torch.nn.functional.relu(base_features) + self.bias
     localized_features = base_features
     for iteration in range(self.iterations):
         localized_features = (1 - self.alpha) * spmm(
             normalized_adjacency_matrix["indices"],
             normalized_adjacency_matrix["values"], localized_features.
             shape[0], localized_features) + self.alpha * base_features
     return localized_features
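For reference, a dense sketch of the propagation rule the loop above applies, H <- (1 - alpha) * A_hat * H + alpha * H_0; the helper below is hypothetical and only mirrors the computation.

import torch

def appnp_propagate_dense(a_hat: torch.Tensor, h0: torch.Tensor,
                          alpha: float, iterations: int) -> torch.Tensor:
    # h0 plays the role of base_features, a_hat of the normalized adjacency matrix.
    h = h0
    for _ in range(iterations):
        h = (1 - alpha) * (a_hat @ h) + alpha * h0
    return h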
Example #5
def conv_f(x,
           stride,
           kernel_size,
           layer,
           transform,
           subspace,
           pad='reflection'):
    #weight = torch.matmul(layer['weight'], subspace).view(layer['w_shape'])
    #bias = torch.matmul(layer['bias'], subspace).view(layer['b_shape'])
    w0 = layer.weight
    b0 = layer.bias

    i, v = transform['weight']._indices(), transform['weight']._values()
    weight = torch_sparse.spmm(i, v, transform['w_num'],
                               subspace).view(transform['w_shape'])
    # weight += w0

    i, v = transform['bias']._indices(), transform['bias']._values()
    bias = torch_sparse.spmm(i, v, transform['b_num'],
                             subspace).view(transform['b_shape'])
    # bias += b0
    # weight = torch.sparse.mm(layer['weight'], subspace).view(layer['w_shape'])
    # bias = torch.sparse.mm(layer['bias'], subspace).view(layer['b_shape'])

    to_pad = int((kernel_size - 1) / 2)

    x = F.pad(x, (to_pad, to_pad, to_pad, to_pad), mode='reflect')
    x = F.conv2d(x, weight, bias, stride)

    return x
Example #6
    def forward(self, normalized_adjacency_matrix, features):
        """
        Doing a forward pass.
        :param normalized_adjacency_matrix: Normalized adjacency matrix.
        :param features: Feature matrix.
        :return base_features: Convolved features.
        """
        feature_count, _ = torch.max(features["indices"], dim=1)
        feature_count = feature_count + 1
        base_features = spmm(features["indices"], features["values"],
                             feature_count[0], feature_count[1],
                             self.weight_matrix)

        base_features = base_features + self.bias

        base_features = torch.nn.functional.dropout(base_features,
                                                    p=self.dropout_rate,
                                                    training=self.training)

        base_features = torch.nn.functional.relu(base_features)
        for _ in range(self.iterations - 1):
            base_features = spmm(normalized_adjacency_matrix["indices"],
                                 normalized_adjacency_matrix["values"],
                                 base_features.shape[0],
                                 base_features.shape[0], base_features)
        return base_features
Example #7
    def forward(self, x, edge_index, edge_attr=None):
        """"""
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        row, col = edge_index
        num_nodes, num_edges, K = x.size(0), row.size(0), self.weight.size(0)

        if edge_attr is None:
            edge_attr = x.new_ones((num_edges, ))
        assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

        deg = degree(row, num_nodes, dtype=x.dtype)

        # Compute normalized and rescaled Laplacian.
        deg = deg.pow(-0.5)
        deg[deg == float('inf')] = 0
        lap = -deg[row] * edge_attr * deg[col]

        # Perform filter operation recurrently.
        Tx_0 = x
        out = torch.mm(Tx_0, self.weight[0])

        if K > 1:
            Tx_1 = spmm(edge_index, lap, num_nodes, x)
            out = out + torch.mm(Tx_1, self.weight[1])

        for k in range(2, K):
            Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
            out = out + torch.mm(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out
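The spmm calls above implement the Chebyshev recurrence T_0(L)x = x, T_1(L)x = Lx, T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x. A dense reference sketch (hypothetical helper, plain PyTorch):

import torch

def cheb_basis_dense(lap: torch.Tensor, x: torch.Tensor, K: int):
    # Returns [T_0(L)x, ..., T_{K-1}(L)x], each of shape [num_nodes, channels].
    Tx = [x]
    if K > 1:
        Tx.append(lap @ x)
    for _ in range(2, K):
        Tx.append(2 * (lap @ Tx[-1]) - Tx[-2])
    return Tx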
Example #8
def meancurvature(pos, faces):
  if pos.shape[-1] != 3:
    raise ValueError("Vertices positions must have shape [n,3]")

  if faces.shape[-1] != 3:
    raise ValueError("Face indices must have shape [m,3]") 

  n = pos.shape[0]
  stiff, mass = laplacebeltrami_FEM_v2(pos, faces)
  ai, av = mass
  mcf = tsparse.spmm(ai, torch.reciprocal(av), n, n, tsparse.spmm(*stiff, n, n, pos))
  return mcf.norm(dim=-1, p=2), stiff, mass
Example #9
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        row, col = edge_index
        num_nodes, num_edges, K = x.size(0), row.size(0), self.K

        if edge_weight is None:
            edge_weight = x.new_ones((num_edges, ))
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        deg = degree(row, num_nodes, dtype=x.dtype)

        # Compute normalized and rescaled Laplacian.
        deg = deg.pow(-0.5)
        deg[deg == float('inf')] = 0
        lap = -deg[row] * edge_weight * deg[col]

        outlist = []

        # Perform filter operation recurrently.
        Tx_0 = x
        out = torch.mm(self.conv_out(Tx_0, 0), self.weight[0])
        outlist.append(out)
        # out = torch.mm(Tx_0, self.weight[0])

        if K > 1:
            Tx_1 = spmm(edge_index, lap, num_nodes, x)
            # out = out + torch.mm(Tx_1, self.weight[1])
            # out = out + torch.mm(self.conv_out(Tx_1, 1), self.weight[1])
            out = torch.mm(self.conv_out(Tx_1, 1), self.weight[1])
            outlist.append(out)

        for k in range(2, K):
            Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0
            # out = out + torch.mm(Tx_2, self.weight[k])
            # out = out + torch.mm(self.conv_out(Tx_2, k), self.weight[k])
            out = torch.mm(self.conv_out(Tx_2, k), self.weight[k])
            outlist.append(out)
            Tx_0, Tx_1 = Tx_1, Tx_2

        out = torch.stack(outlist, dim=0)
        out = torch.sum(out, dim=0)

        if self.bias is not None:
            out = out + self.bias

        return out
Example #10
 def forward(self, personalized_page_rank_matrix, features, dropout_rate,
             transform, density):
     """
     Doing a forward pass.
     :param personalized_page_rank_matrix: Dense personalized pagerank matrix.
     :param features: Feature matrix.
     :param dropout_rate: Dropout value.
     :param transform: Activation function application rule.
      :param density: Density structure of the feature matrix.
     :return localized_features: Convolved features.
     """
     if density:
         filtered_features = torch.mm(features, self.weight_matrix)
     else:
         filtered_features = spmm(features["indices"], features["values"],
                                  features["dimensions"][0],
                                  self.weight_matrix)
     filtered_features = torch.nn.functional.dropout(filtered_features,
                                                     p=dropout_rate,
                                                     training=self.training)
     if transform:
         filtered_features = torch.nn.functional.relu(filtered_features)
     localized_features = torch.mm(personalized_page_rank_matrix,
                                   filtered_features)
     localized_features = localized_features + self.bias
     return localized_features
Example #11
    def forward(self, x, edge_index, edge_attr=None):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        if edge_attr is None:
            edge_attr = x.new_ones((edge_index.size(1), ))
        assert edge_attr.dim() == 1 and edge_attr.numel() == edge_index.size(1)

        # Add self-loops to adjacency matrix.
        edge_index = add_self_loops(edge_index, x.size(0))
        loop_value = x.new_full((x.size(0), ), 1 if not self.improved else 2)
        edge_attr = torch.cat([edge_attr, loop_value], dim=0)

        # Normalize adjacency matrix.
        row, col = edge_index
        deg = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0))
        deg = deg.pow(-0.5)
        deg[deg == float('inf')] = 0
        edge_attr = deg[row] * edge_attr * deg[col]

        # Perform the convolution.
        out = torch.mm(x, self.weight)
        out = spmm(edge_index, edge_attr, out.size(0), out)

        if self.bias is not None:
            out = out + self.bias

        return out
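In dense form the forward pass above computes out = D^(-1/2) (A + I) D^(-1/2) X W (with the self-loop weight set to 2 when improved=True). A hypothetical dense sketch:

import torch

def gcn_forward_dense(adj: torch.Tensor, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Add self-loops, symmetrically normalize, then aggregate the transformed features.
    a_hat = adj + torch.eye(adj.size(0))
    deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    a_hat = deg_inv_sqrt.view(-1, 1) * a_hat * deg_inv_sqrt.view(1, -1)
    return a_hat @ (x @ weight)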
Example #12
    def forward(self, x, edge_index, edge_attr):
        N, dim = x.shape
        # x = self.dropout(x)

        # adj_mat_ind, adj_mat_val = add_self_loops(edge_index, num_nodes=N)[0], edge_attr.squeeze()
        adj_mat_ind = add_remaining_self_loops(edge_index, num_nodes=N)[0]
        adj_mat_val = torch.ones(adj_mat_ind.shape[1]).to(x.device)

        h = torch.mm(x, self.weight)
        h = F.dropout(h, p=self.dropout, training=self.training)
        for _ in range(self.nhop - 1):
            adj_mat_ind, adj_mat_val = spspmm(adj_mat_ind, adj_mat_val,
                                              adj_mat_ind, adj_mat_val, N, N,
                                              N, True)

        adj_mat_ind, adj_mat_val = self.attention(h, adj_mat_ind, adj_mat_val)

        # MATRIX_MUL
        # laplacian matrix normalization
        adj_mat_val = self.normalization(adj_mat_ind, adj_mat_val, N)

        val_h = h
        # N, dim = val_h.shape

        # MATRIX_MUL
        # val_h = spmm(adj_mat_ind, F.dropout(adj_mat_val, p=self.node_dropout, training=self.training), N, N, val_h)
        val_h = spmm(adj_mat_ind, adj_mat_val, N, N, val_h)

        val_h[val_h != val_h] = 0
        val_h = val_h + self.bias
        val_h = self.adaptive_enc(val_h)
        val_h = F.dropout(val_h, p=self.dropout, training=self.training)
        # val_h = self.activation(val_h)
        return val_h
Example #13
    def loop_sparse_attention_centrality(self, attention, idx):
        # O(N) implementation

        batch, groups, npoints, neighbors = attention.size()
        idx_tag = torch.tensor([[i] * neighbors for i in range(npoints)],
                               device='cuda').flatten().unsqueeze(0)
        mtrx = torch.tensor([[1.] * npoints], device='cuda').T

        score = []

        for i in range(batch):
            idx_flatten = idx[i].flatten().unsqueeze(0)  # NK
            index = torch.cat([idx_tag, idx_flatten], dim=0)
            score_group = []
            for j in range(groups):
                attention_flatten = attention[i][j].flatten()
                index_s, value_s = coalesce(index, attention_flatten, npoints,
                                            npoints)
                index_t, value_t = transpose(index_s, value_s, npoints,
                                             npoints)
                out = spmm(index_t, value_t, npoints, npoints, mtrx)
                score_group.append(out)
                if j == groups - 1:
                    score.append(torch.cat(score_group, dim=1).unsqueeze(0))
            if i == batch - 1:
                final_score = torch.cat(score, dim=0)

        # fullnl instance

        final_score = final_score.unsqueeze(3).permute(0, 2, 3, 1)
        idx_value, idx_score = final_score.topk(
            k=neighbors, dim=3)  # B, G, 1, N -> B, G, 1, K'
        return idx_value, idx_score
Example #14
    def forward(self, phi_indices, phi_values, phi_inverse_indices,
                phi_inverse_values, features):
        """
        Forward propagation pass.
        :param phi_indices: Sparse wavelet matrix index pairs.
        :param phi_values: Sparse wavelet matrix values.
        :param phi_inverse_indices: Inverse wavelet matrix index pairs.
        :param phi_inverse_values: Inverse wavelet matrix values.
        :param features: Feature matrix.
        :return localized_features: Filtered feature matrix extracted.
        """
        rescaled_phi_indices, rescaled_phi_values = spspmm(
            phi_indices, phi_values, self.diagonal_weight_indices,
            self.diagonal_weight_filter.view(-1), self.ncount, self.ncount,
            self.ncount)

        phi_product_indices, phi_product_values = spspmm(
            rescaled_phi_indices, rescaled_phi_values, phi_inverse_indices,
            phi_inverse_values, self.ncount, self.ncount, self.ncount)

        filtered_features = torch.mm(features, self.weight_matrix)

        localized_features = spmm(phi_product_indices, phi_product_values,
                                  self.ncount, filtered_features)

        return localized_features
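The two spspmm calls followed by the final spmm above amount, in dense form, to localized_features = Phi diag(f) Phi^{-1} (X W). A hypothetical dense sketch:

import torch

def wavelet_filter_dense(phi, phi_inv, filter_diag, x, weight):
    # phi, phi_inv: [n, n] wavelet basis and its inverse; filter_diag: [n] learned filter values.
    return phi @ torch.diag(filter_diag) @ phi_inv @ (x @ weight)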
Example #15
 def do_conv(self, x):
     #orig = x.data.clone()
     #x = convInter[0](x, egoNets[0].edge_index.to(device).data)
     #print(torch.ones((egoNets[0].edge_index.shape[1])))
     #output = torch_sparse.spmm(self.egoNets[0].edge_index.to(self.device), torch.ones((self.egoNets[0].edge_index.shape[1],)).to(self.device), self.numNodes, self.numNodes, x)
     output = torch_sparse.spmm(
         self.egoNets[0].ego_norm_ind.to(self.device),
         self.egoNets[0].ego_norm_val.to(self.device), self.numNodes,
         self.numNodes, x)
     for power in range(local_power - 1):
         #output = torch_sparse.spmm(self.egoNets[0].edge_index.to(self.device), torch.ones((self.egoNets[0].edge_index.shape[1],)).to(self.device), self.numNodes, self.numNodes, output)
         output = torch_sparse.spmm(
             self.egoNets[0].ego_norm_ind.to(self.device),
             self.egoNets[0].ego_norm_val.to(self.device), self.numNodes,
             self.numNodes, output)
     #output = convInter(x, egoNets[0].edge_index.to(device))
     for i, ego in enumerate(self.egoNets):
         if i == 0:
             continue
         #cpu_x = x.to('cpu')
         #del x
         #torch.cuda.empty_cache()
         #cur_edge_index = ego.edge_index.to(device)
         #cur_conv = convInter[i](orig, cur_edge_index)
         #del cur_edge_index
         #torch.cuda.empty_cache()
         #x = cpu_x.to(device) + cur_conv
         #del cpu_x
         #torch.cuda.empty_cache()
         #output = output + convInter(x, ego.edge_index.to(device))
         #values = ego.ego_degrees
         #values = torch.ones((ego.edge_index.shape[1]))
         #temp_out = torch_sparse.spmm(ego.edge_index.to(self.device), torch.ones((ego.edge_index.shape[1])).to(self.device), self.numNodes, self.numNodes, x)
         temp_out = torch_sparse.spmm(ego.ego_norm_ind.to(self.device),
                                      ego.ego_norm_val.to(self.device),
                                      self.numNodes, self.numNodes, x)
         for power in range(local_power - 1):
             #temp_out = torch_sparse.spmm(ego.edge_index.to(self.device), torch.ones((ego.edge_index.shape[1])).to(self.device), self.numNodes, self.numNodes, temp_out)
             temp_out = torch_sparse.spmm(ego.ego_norm_ind.to(self.device),
                                          ego.ego_norm_val.to(self.device),
                                          self.numNodes, self.numNodes,
                                          temp_out)
         output = output + temp_out
     #output = output * (1 / self.numNodes)
     output = torch.mul(output, self.norm_degrees)
     torch.cuda.empty_cache()
     return output
Example #16
 def reference(self, column_index, val, num_nodes):
     '''
     Compute reference SpMM (neighbor aggregation)
     result on CPU.
     '''
     print("# Compute reference on CPU")
     self.result_ref = spmm(torch.tensor(column_index,  dtype=torch.int64), \
                             torch.FloatTensor(val), num_nodes, num_nodes, self.X)
Example #17
    def forward(self, edge_index, edge_attr, N):
        device = edge_attr.device
        ones = torch.ones(N, 1, device=device)
        rownorm = 1. / spmm(edge_index, edge_attr, N, N, ones).view(-1)
        col = rownorm[edge_index[1]]
        edge_attr_t = col * edge_attr

        return edge_attr_t
Example #18
    def forward(self, x):
        """"""
        K, lap, edge_index, num_nodes = self.K, self.lap, self.edge_index, self.num_nodes
        assert (num_nodes == x.shape[1])
        # Perform filter operation recurrently.
        Tx_0 = x
        out = torch.matmul(Tx_0, self.weight[0])

        if K > 1:

            # Tx_1 = spmm(edge_index, lap, num_nodes, x)

            Tx_1 = spmm(edge_index, lap, num_nodes,
                        x.permute(1, 0, 2).reshape(num_nodes, -1))
            Tx_1 = Tx_1.reshape(num_nodes, -1,
                                self.in_channels).permute(1, 0, 2)

            # Tx_1 = batch_spmm(edge_index, lap, num_nodes, x)

            # sparse matrix multiplication is not compatible with multi-gpu, so we use dense mat mul
            # Tx_1 = torch.matmul(lap, x)

            out = out + torch.matmul(Tx_1, self.weight[1])

        for k in range(2, K):

            # Tx_2 = 2 * spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0

            Tx_2 = 2 * spmm(edge_index, lap, num_nodes,
                            Tx_1.permute(1, 0, 2).reshape(num_nodes, -1))
            Tx_2 = Tx_2.reshape(num_nodes, -1, self.in_channels).permute(
                1, 0, 2) - Tx_0

            # Tx_2 = 2 * batch_spmm(edge_index, lap, num_nodes, Tx_1) - Tx_0

            # sparse matrix multiplication is not compatible with multi-gpu, so we use dense mat mul
            # Tx_2 = 2 * torch.matmul(lap, Tx_1) - Tx_0

            out = out + torch.matmul(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out
Example #19
 def forward(self, normalized_adjacency_matrix, features):
     """
     Doing a forward pass.
     :param normalized_adjacency_matrix: Normalized adjacency matrix.
     :param features: Feature matrix.
     :return base_features: Convolved features.
     """
     base_features = spmm(features["indices"], features["values"],
                          features["dimensions"][0], self.weight_matrix)
     base_features = torch.nn.functional.dropout(base_features,
                                                 p=self.dropout_rate,
                                                 training=self.training)
     base_features = torch.nn.functional.relu(base_features) + self.bias
     for iteration in range(self.iterations):
         base_features = spmm(normalized_adjacency_matrix["indices"],
                              normalized_adjacency_matrix["values"],
                              base_features.shape[0], base_features)
     return base_features
Example #20
def bn_f(x, layer, bn_module, subspace):
    run_mean = bn_module.running_mean
    run_var = bn_module.running_var

    #weight = torch.matmul(layer['weight'], subspace).view(layer['w_shape'])
    #bias = torch.matmul(layer['bias'], subspace).view(layer['b_shape'])
    # weight = torch.sparse.mm(layer['weight'], subspace).view(layer['w_shape'])
    # bias = torch.sparse.mm(layer['bias'], subspace).view(layer['b_shape'])
    i, v = layer['weight']._indices(), layer['weight']._values()
    weight = torch_sparse.spmm(i, v, layer['w_num'],
                               subspace).view(layer['w_shape'])
    i, v = layer['bias']._indices(), layer['bias']._values()
    bias = torch_sparse.spmm(i, v, layer['b_num'],
                             subspace).view(layer['b_shape'])
    y = F.batch_norm(x, run_mean, run_var, weight, bias, training=True)
    dummy_y = bn_module(x)

    return y
Example #21
def test_spmm(dtype, device):
    row = torch.tensor([0, 0, 1, 2, 2], device=device)
    col = torch.tensor([0, 2, 1, 0, 1], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 4, 1, 3], dtype, device)
    x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)

    out = spmm(index, value, 3, 3, x)
    assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
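The expected result can be cross-checked against a dense matrix product with plain PyTorch (a minimal sanity check, not part of the original test):

import torch

dense = torch.zeros(3, 3)
dense[[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]] = torch.tensor([1., 2., 4., 1., 3.])
x = torch.tensor([[1., 4.], [2., 5.], [3., 6.]])
assert torch.equal(dense @ x, torch.tensor([[7., 16.], [8., 20.], [7., 19.]]))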
Example #22
 def forward(self, feature, adj):
     #        support = torch.spmm(feature, self.weight) # sparse
     #        output = torch.spmm(adj, support)
     support = torch.mm(feature, self.weight)  # sparse
     output = spmm(adj._indices(), adj._values(), adj.size(0), support)
     if self.bias is not None:
         return output + self.bias
     else:
         return output
Example #23
    def forward(self, edge_index, edge_attr, N):
        device = edge_attr.device
        ones = torch.ones(N, 1, device=device)
        rownorm = spmm(edge_index, edge_attr, N, N, ones).view(-1).pow(-0.5)
        row = rownorm[edge_index[0]]
        col = rownorm[edge_index[1]]
        edge_attr_t = row * edge_attr * col

        return edge_attr_t
Example #24
def test_spmm():
    row = torch.tensor([0, 0, 1, 2, 2])
    col = torch.tensor([0, 2, 1, 0, 1])
    index = torch.stack([row, col], dim=0)
    value = torch.tensor([1, 2, 4, 1, 3])

    matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
    out = spmm(index, value, 3, matrix)
    assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
Example #25
    def forward(self, edge_index, edge_attr, N):
        device = edge_attr.device
        edge_attr_t = torch.exp(edge_attr)
        ones = torch.ones(N, 1, device=device)
        rownorm = 1. / spmm(edge_index, edge_attr_t, N, N, ones).view(-1)
        row = rownorm[edge_index[0]]
        edge_attr_t = row * edge_attr_t

        return edge_attr_t
Example #26
def mgunpool(x, index, values, origsize, newsize):
    # newsize - pooled size, origsize - unpooled size, P comes as nc x n
    index, values = torch_sparse.coalesce(index, values, m=origsize,
                                          n=newsize)  # P matrix
    new_feat = torch_sparse.spmm(index,
                                 values,
                                 m=origsize,
                                 n=newsize,
                                 matrix=x)  # P^T X

    return new_feat
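A hypothetical usage sketch with toy sizes (relies on the mgunpool defined above and an installed torch_sparse): unpool 2 pooled node features back to 4 original nodes through an assignment matrix of shape [origsize, newsize] given in COO form.

import torch

origsize, newsize, channels = 4, 2, 3
index = torch.tensor([[0, 1, 2, 3],    # rows: original (unpooled) nodes
                      [0, 0, 1, 1]])   # cols: pooled nodes they receive features from
values = torch.ones(index.size(1))
x_pooled = torch.randn(newsize, channels)

x_unpooled = mgunpool(x_pooled, index, values, origsize, newsize)
print(x_unpooled.shape)  # torch.Size([4, 3])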
Example #27
    def forward(self, features, meshes):
        batch_size, nf, edges = features.shape
        groups = [mesh.get_groups() for mesh in meshes]
        #groups = [self.pad_groups(group, edges) for group in og_groups]
        #unroll_mat = torch.cat(groups, dim=0).view(batch_size, edges, -1)
        og_occu = [mesh.get_occurrences() for mesh in meshes]
        occurrences = [self.pad_occurrences(mesh) for mesh in og_occu]
        occurrences = torch.cat(occurrences, dim=0).view(batch_size, 1, -1)
        occurrences = occurrences.expand(
            (batch_size, edges, self.unroll_target))
        #unroll_mat = unroll_mat / occurrences
        #unroll_mat = unroll_mat.to(features.device)

        groups = [self.sparse_pad_groups(mesh, edges) for mesh in groups]
        indices = torch.cat([
            torch.cat([
                torch.ones((1, g.indices().shape[-1]), dtype=torch.int64).to(
                    features.device) * idx,
                g.indices()
            ],
                      dim=0) for idx, g in enumerate(groups)
        ],
                            dim=1)
        values = torch.cat([g.values() for g in groups], dim=0)

        values = values / occurrences[indices[0, :], indices[1, :],
                                      indices[2, :]]
        #groups = torch.sparse.FloatTensor(indices, values, (batch_size, edges, self.unroll_target)).coalesce()

        #return torch.matmul(features, unroll_mat)

        if self.result is None or self.result.shape != (
                batch_size, features.shape[1], self.unroll_target):
            result = torch.zeros(
                (batch_size, features.shape[1], self.unroll_target),
                device=features.device)

        # transpose
        b, row, col = indices
        indices = torch.stack([b, col, row], dim=0)
        transposed_features = features.transpose(1, 2)

        for b_idx in range(batch_size):
            mask = indices[0, :] == b_idx
            #tmp = torch.sparse.FloatTensor(indices[1:, mask], values[mask], (self.unroll_target, edges))
            #result[b_idx, :, :] = torch.sparse.mm(tmp, transposed_features[b_idx, :, :]).T
            result[b_idx, :, :] = spmm(indices[1:, mask], values[mask],
                                       self.unroll_target, edges,
                                       transposed_features[b_idx, :, :]).T

        for mesh in meshes:
            mesh.unroll_gemm()

        return result
Example #28
    def forward(self, phi_indices, phi_values, phi_inverse_indices,
                phi_inverse_values, feature_indices, feature_values, dropout):

        rescaled_phi_indices, rescaled_phi_values = spspmm(
            phi_indices, phi_values, self.diagonal_weight_indices,
            self.diagonal_weight_filter.view(-1), self.ncount, self.ncount,
            self.ncount)

        phi_product_indices, phi_product_values = spspmm(
            rescaled_phi_indices, rescaled_phi_values, phi_inverse_indices,
            phi_inverse_values, self.ncount, self.ncount, self.ncount)
        filtered_features = spmm(feature_indices, feature_values, self.ncount,
                                 self.weight_matrix)
        localized_features = spmm(phi_product_indices, phi_product_values,
                                  self.ncount, filtered_features)
        dropout_features = torch.nn.functional.dropout(
            torch.nn.functional.relu(localized_features),
            training=self.training,
            p=dropout)
        return dropout_features
Example #29
    def get_feat(self, graph_list):
        if self.feat_mode == 'dense':
            dense_feat = self.get_fp(graph_list)
        else:
            sp_indices, vals = self.get_fp(graph_list)
            w = self.input_linear.weight
            b = self.input_linear.bias
            dense_feat = spmm(sp_indices, vals, len(graph_list),
                              w.transpose(0, 1)) + b

        return self.mlp(dense_feat), None
Example #30
 def _spmm(self, inp, params):
     ii, vv, size = params
     old_inp_size = inp.size()
     inp_flat_T = inp.view(-1, inp.size(-1)).t()
     out_flat = torch_sparse.spmm(ii,
                                  vv,
                                  m=size[0],
                                  n=size[1],
                                  matrix=inp_flat_T).t()
     out = out_flat.view(*old_inp_size)
     return out