Example #1
def naive_fuse_gpu(s, d, w):
    with torch.no_grad():
        y_sum = torch_sparse.matmul(s, d, reduce="sum")
        y_min = torch_sparse.matmul(s, d, reduce="min")
        y_max = torch_sparse.matmul(s, d, reduce="max")
        w = w.unsqueeze(-1)
        y = (y_sum * w[:, 0]) + (y_min * w[:, 1]) + (y_max * w[:, 2])
        torch.cuda.synchronize()
        return y
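A minimal usage sketch for the function above (not from the original source): it assumes `s` is a square torch_sparse.SparseTensor, `d` is a dense feature matrix, `w` holds one weight per reducer for each row, and a CUDA device is available.

# Hypothetical usage sketch for naive_fuse_gpu; shapes and values are illustrative.
import torch
import torch_sparse

N, F_dim = 4, 8
row = torch.tensor([0, 1, 2, 3])
col = torch.tensor([1, 2, 3, 0])
s = torch_sparse.SparseTensor(row=row, col=col, value=torch.ones(4),
                              sparse_sizes=(N, N)).cuda()   # N x N sparse adjacency
d = torch.randn(N, F_dim, device='cuda')                    # N x F dense features
w = torch.rand(N, 3, device='cuda')                         # per-row weights for sum/min/max
y = naive_fuse_gpu(s, d, w)                                 # -> (N, F_dim)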
    def init_adj(self, edge_index):
        """ cache normalized adjacency and normalized strict two-hop adjacency,
        neither has self loops
        """
        n = self.num_nodes
        
        if isinstance(edge_index, SparseTensor):
            adj_t = edge_index
            dev = adj_t.device()
            # binarize the adjacency values
            adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
            adj_t[adj_t > 0] = 1
            adj_t[adj_t < 0] = 0
            adj_t = SparseTensor.from_scipy(adj_t).to(dev)
        elif isinstance(edge_index, torch.Tensor):
            dev = edge_index.device
            row, col = edge_index
            adj_t = SparseTensor(row=col, col=row, value=None, sparse_sizes=(n, n))

        # drop self loops, then build the strict two-hop adjacency A^2 - A
        adj_t = adj_t.remove_diag(0)
        adj_t2 = matmul(adj_t, adj_t)
        adj_t2 = adj_t2.remove_diag(0)
        adj_t = scipy.sparse.csr_matrix(adj_t.to_scipy())
        adj_t2 = scipy.sparse.csr_matrix(adj_t2.to_scipy())
        adj_t2 = adj_t2 - adj_t
        adj_t2[adj_t2 > 0] = 1
        adj_t2[adj_t2 < 0] = 0

        adj_t = SparseTensor.from_scipy(adj_t)
        adj_t2 = SparseTensor.from_scipy(adj_t2)
        
        adj_t = gcn_norm(adj_t, None, n, add_self_loops=False)
        adj_t2 = gcn_norm(adj_t2, None, n, add_self_loops=False)

        self.adj_t = adj_t.to(dev)
        self.adj_t2 = adj_t2.to(dev)
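The docstring above describes a "strict two-hop" adjacency: pairs of nodes reachable in exactly two hops, with direct neighbors and self loops removed. A small hedged sketch of that construction on a 3-node path graph, using only operations already shown (names are illustrative):

# Sketch: strict two-hop adjacency of the path 0-1-2; only (0,2) and (2,0) survive.
import torch
import scipy.sparse
from torch_sparse import SparseTensor, matmul

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
adj = SparseTensor(row=row, col=col, value=torch.ones(4), sparse_sizes=(3, 3))

adj = adj.remove_diag(0)
adj2 = matmul(adj, adj).remove_diag(0)          # two-hop reachability (self loops dropped)
adj_sp = scipy.sparse.csr_matrix(adj.to_scipy())
adj2_sp = scipy.sparse.csr_matrix(adj2.to_scipy()) - adj_sp
adj2_sp[adj2_sp > 0] = 1                        # keep strictly-two-hop pairs
adj2_sp[adj2_sp < 0] = 0                        # drop pairs that are also one-hop
print(SparseTensor.from_scipy(adj2_sp).to_dense())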
Example #3
 def message_and_aggregate(
     self,
     adj_t: SparseTensor,
     x: _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor,
                                                  torch.Tensor]],
 ) -> torch.Tensor:
     return matmul(adj_t, x[0], reduce=self.aggr)
 def forward(self, h, h_target, adj_t):
     adj_t = adj_t.set_value(
         None, layout=None
     )  # torch_sparse.matmul will throw an error without this line
     h_n = torch_sparse.matmul(adj_t, h, reduce='mean')
     # h = self.linear(torch.cat((h_target, h_n), dim=1))  # to make my life easier, I don't concat here.
     h = self.linear_1(h_target) + self.linear_2(h_n)
     h = F.normalize(h, p=2, dim=1)
     return h
 def forward(self, x, adj_t):
     xs = [self.lins[0](x)]
     for j in range(1, self.hops + 1):
         # less runtime efficient but usually more memory efficient to mult weight matrix first
         x_j = self.lins[j](x)
         for hop in range(j):
             x_j = matmul(adj_t, x_j)
         xs += [x_j]
     return torch.cat(xs, dim=1)
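The comment in the loop above relies on propagation and the linear map commuting: because both are linear, A(XW) equals (AX)W, so the ordering only changes runtime and memory cost, not the result. A small hedged check of that equivalence (toy tensors, bias disabled so the two orderings match exactly):

# Sketch: applying the weight matrix before or after sparse propagation gives
# the same output (up to floating-point error) when the layer has no bias.
import torch
from torch_sparse import SparseTensor, matmul

N, F_in, F_out = 5, 16, 4
row = torch.randint(0, N, (20,))
col = torch.randint(0, N, (20,))
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
x = torch.randn(N, F_in)
lin = torch.nn.Linear(F_in, F_out, bias=False)

weight_first = matmul(adj_t, lin(x))     # can be more memory friendly when F_out < F_in
propagate_first = lin(matmul(adj_t, x))
print(torch.allclose(weight_first, propagate_first, atol=1e-5))  # expect: True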
    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        adj_t_2 = adj_t
        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
            adj_t_2 = adj_t.set_value(None)

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                out = matmul(adj_t, x, reduce='sum')
            elif aggr in ['var', 'std']:
                mean = matmul(adj_t_2, x, reduce='mean')
                mean_sq = matmul(adj_t_2, x * x, reduce='mean')
                out = mean_sq - mean * mean
                if aggr == 'std':
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = matmul(adj_t_2, x, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]
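The 'var' and 'std' branches above use the identity Var[x] = E[x^2] - E[x]^2, computed per neighborhood by two mean-reduced SpMMs. A small hedged check of that identity against torch.var on one node's neighbor features (illustrative only):

# Sketch: population variance via E[x^2] - E[x]^2 for a single neighborhood.
import torch

neigh = torch.randn(7, 16)                        # features of one node's 7 neighbors
mean = neigh.mean(dim=0)
mean_sq = (neigh * neigh).mean(dim=0)
var_identity = mean_sq - mean * mean
var_reference = neigh.var(dim=0, unbiased=False)  # population (biased) variance
print(torch.allclose(var_identity, var_reference, atol=1e-5))  # expect: True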
Example #7
    def message_and_aggregate(self, edge_index, node_feature_neigh):
        r"""
        This function fuses :meth:`message` and :meth:`aggregate` into a single
        step, which saves memory by avoiding explicit message materialization.
        For more information, please refer to the PyTorch Geometric documentation.

        Args:
            edge_index (:class:`torch_sparse.SparseTensor`): The `edge_index` sparse tensor.
            node_feature_neigh (:class:`torch.Tensor`): Neighbor feature tensor.
        """
        out = matmul(edge_index, node_feature_neigh, reduce="mean")
        return out
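As the docstring notes, the fused call computes the messages and their aggregation in one SpMM instead of materializing a per-edge message tensor. A hedged sketch of the equivalent unfused computation with torch_scatter (illustrative graph, mean aggregation as above):

# Sketch: fused SpMM vs. explicit gather + scatter; both average neighbor features.
import torch
from torch_sparse import SparseTensor, matmul
from torch_scatter import scatter

N, F_dim = 6, 8
row = torch.tensor([0, 0, 1, 2, 3, 5])   # target node of each edge
col = torch.tensor([1, 2, 3, 4, 5, 0])   # source (neighbor) node of each edge
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
x = torch.randn(N, F_dim)

fused = matmul(adj_t, x, reduce="mean")
messages = x[col]                                              # materialized per-edge messages
unfused = scatter(messages, row, dim=0, dim_size=N, reduce="mean")
print(torch.allclose(fused, unfused, atol=1e-6))               # expect: True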
Example #8
    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # fused neighbor aggregation (reduce given by self.aggr), replacing the
        # commented-out explicit scatter:
        # out = torch_scatter.scatter(src=x, index=adj_t, dim_size=x.size(0), reduce='sum')
        neigh_emb = matmul(adj_t, x, reduce=self.aggr)
        # FiLM-style transformation of the aggregated embedding, then LeakyReLU
        w_emb = self.edge_type_to_film_computation_layers(neigh_emb)
        out = torch.nn.functional.leaky_relu(w_emb)
        return out
Example #9
    def propagate(self, edge_index, size, x, h, edge_weight):

        # message and aggregate
        if size is None:
            size = [x.size(0), x.size(0)]

        adj = torch_sparse.SparseTensor(row=edge_index[0],
                                        rowptr=None,
                                        col=edge_index[1],
                                        value=edge_weight,
                                        sparse_sizes=torch.Size(size),
                                        is_sorted=True)  # edge_index is assumed to be sorted by row
        out = torch_sparse.matmul(adj, h, reduce='sum')
        # out = torch.cat([out, self.lin_root(x)], dim=1)
        out = out + self.lin_root(x)
        return out
    def forward(self, data, train_idx):
        n = data.graph['num_nodes']
        edge_index = data.graph['edge_index']
        edge_weight=None

        if isinstance(edge_index, torch.Tensor):
            edge_index, edge_weight = gcn_norm( 
                edge_index, edge_weight, n, False)
            row, col = edge_index
            # transposed if directed
            adj_t = SparseTensor(row=col, col=row, value=edge_weight, sparse_sizes=(n, n))
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(
                edge_index, edge_weight, n, False)
            edge_weight=None
            adj_t = edge_index

        y = torch.zeros((n, self.out_channels)).to(adj_t.device())
        if data.label.shape[1] == 1:
            # make one hot
            y[train_idx] = F.one_hot(data.label[train_idx], self.out_channels).squeeze(1).to(y)
        elif self.mult_bin:
            y = torch.zeros((n, 2*self.out_channels)).to(adj_t.device())
            for task in range(data.label.shape[1]):
                y[train_idx, 2*task:2*task+2] = F.one_hot(data.label[train_idx, task], 2).to(y)
        else:
            y[train_idx] = data.label[train_idx].to(y.dtype)
        result = y.clone()
        for _ in range(self.num_iters):
            for _ in range(self.hops):
                result = matmul(adj_t, result)
            result *= self.alpha
            result += (1-self.alpha)*y

        if self.mult_bin:
            output = torch.zeros((n, self.out_channels)).to(result.device)
            for task in range(data.label.shape[1]):
                output[:, task] = result[:, 2*task+1]
            result = output

        return result
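The loop above is a personalized-PageRank-style label propagation: the one-hot seed labels are smoothed through the (gcn_norm-normalized) adjacency for `hops` steps, then mixed back with the seeds by `alpha`. A compact hedged sketch of one such iteration on a toy graph (illustrative names and values; normalization omitted for brevity):

# Sketch of the core update: result <- alpha * A^hops @ result + (1 - alpha) * y
import torch
from torch_sparse import SparseTensor, matmul

n, num_classes, hops, alpha = 5, 3, 2, 0.9
row = torch.tensor([0, 1, 1, 2, 3, 4])
col = torch.tensor([1, 0, 2, 1, 4, 3])
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(n, n))

y = torch.zeros(n, num_classes)
y[0, 0] = 1.0                      # labeled seed: node 0, class 0
y[3, 2] = 1.0                      # labeled seed: node 3, class 2

result = y.clone()
for _ in range(hops):
    result = matmul(adj_t, result)
result = alpha * result + (1 - alpha) * y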
Example #11
def bench_spmm(g, ctx, binary_op, reduce_op):
    assert binary_op == 'copy_u'
    adj_t = g[1].to(ctx)
    ptr = adj_t.storage.rowptr().to(ctx)
    row, col = g[0]
    row = row.to(ctx)
    col = col.to(ctx)
    print("SPMM\n----------------------------")
    with th.no_grad():
        for n_hid in [1, 2, 4, 8, 16, 32, 64, 128]:
            try:
                nfeat = th.rand(adj_t.size(1), n_hid, device=ctx)
                efeat = None
                accum_time = 0
                n_runs = 10
                for n_times in range(n_runs):
                    with th_op_time() as timer:
                        out = matmul(adj_t, nfeat, reduce=reduce_op)
                    if n_times >= n_cold_start:  # skip warm-up iterations
                        accum_time += timer.time
                avg_time = accum_time / (n_runs - n_cold_start)
                print('hidden size: {}, avg time: {}'.format(n_hid, avg_time))
            except RuntimeError:  # typically CUDA out of memory
                print('hidden size: {}, OOM'.format(n_hid))
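The benchmark assumes a `th_op_time` timing context manager and an `n_cold_start` warm-up constant defined elsewhere in the original script; a minimal hypothetical sketch of what such helpers could look like, using CUDA events:

# Hypothetical stand-ins for the helpers bench_spmm expects (not the original code).
import torch as th
from contextlib import contextmanager

n_cold_start = 2  # number of warm-up iterations excluded from the average

@contextmanager
def th_op_time():
    class _Timer:
        time = 0.0
    timer = _Timer()
    start = th.cuda.Event(enable_timing=True)
    end = th.cuda.Event(enable_timing=True)
    start.record()
    yield timer                 # the timed op runs inside the with-block
    end.record()
    th.cuda.synchronize()
    timer.time = start.elapsed_time(end) / 1e3  # CUDA events report milliseconds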
    def forward(self, data):
        edge_index = data.graph['edge_index']
        x = data.graph['node_feat']
        x = self.lin(x)
        n = data.graph['num_nodes']
        edge_weight=None

        if isinstance(edge_index, torch.Tensor):
            edge_index, edge_weight = gcn_norm( 
                edge_index, edge_weight, n, False,
                 dtype=x.dtype)
            row, col = edge_index
            adj_t = SparseTensor(row=col, col=row, value=edge_weight, sparse_sizes=(n, n))
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(
                edge_index, edge_weight, n, False,
                dtype=x.dtype)
            edge_weight=None
            adj_t = edge_index

        for _ in range(self.hops):
            x = matmul(adj_t, x)
        
        return x
Example #13
 def message_and_aggregate(self, adj_t, x):
     """
     Sparse matrix multiplication
     """
     return matmul(adj_t, x, reduce=self.aggr)
Example #14
 def message_and_aggregate(self, adj_t, x):
     return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
 def forward(self, h, adj_t):
     h_n = torch_sparse.matmul(adj_t, h, reduce='mean')
     # h = self.linear(torch.cat((h, h_n), dim=1))  # to make my life easier, I don't concat here.
     h = self.linear_1(h) + self.linear_2(h_n)
     h = F.normalize(h, p=2, dim=1)
     return h
Example #16
 def forward(self, x, adj_t):
     x = x.matmul(self.weight)
     return torch_sparse.matmul(adj_t, x) + self.bias
Example #17
def make_multihop_edges(data, k):
    """
    Adds edges corresponding to distances up to k to a data object.
    :param data: torch_geometric.data object, in coo format
    (ie an edge (i, j) with label v is stored with an arbitrary index u as:
     edge_index[0, u] = i, edge_index[1, u]=j, edge_attr[u]=v)
    :return: a new data object with new fields, multihop_edge_index and distance.
    distance[u] contains values from 1 to k corresponding to the distance between
    multihop_edge_index[0, u] and multihop_edge_index[1, u]
    """
    data = new(data)

    N = data.num_nodes
    E = data.num_edges
    if E == 0:
        data.multihop_edge_index = torch.empty_like(data.edge_index)
        data.distance = torch.empty_like(data.multihop_edge_index[0])
        return data

    # Get the distance 0
    multihop_edge_index = torch.arange(0,
                                       N,
                                       dtype=data.edge_index[0].dtype,
                                       device=data.x.device)
    distance = torch.zeros_like(multihop_edge_index)
    multihop_edge_index = multihop_edge_index.unsqueeze(0).repeat(2, 1)

    # Get the distance 1
    multihop_edge_index = torch.cat((multihop_edge_index, data.edge_index),
                                    dim=1)
    distance = torch.cat((distance, torch.ones_like(data.edge_index[0])),
                         dim=0)

    A = SparseTensor(row=data.edge_index[0],
                     col=data.edge_index[1],
                     value=torch.ones_like(data.edge_index[0]).float(),
                     sparse_sizes=(N, N),
                     is_sorted=False)
    Ad = A  # A to the power d

    # Get larger distances
    for d in range(2, k + 1):
        Ad = torch_sparse.matmul(Ad, A)
        row, col, v = Ad.coo()
        d_edge_index = torch.stack((row, col))
        d_edge_attr = torch.empty_like(row).fill_(d)
        multihop_edge_index = torch.cat((multihop_edge_index, d_edge_index),
                                        dim=1)
        distance = torch.cat((distance, d_edge_attr), dim=0)

    # remove duplicates, keeping only the shortest distance for each pair
    multihop_edge_index, distance = coalesce(multihop_edge_index,
                                             distance,
                                             N,
                                             N,
                                             op='min')

    data.multihop_edge_index = multihop_edge_index
    data.distance = distance

    return data
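A hedged usage sketch for `make_multihop_edges` on a 4-node path graph, assuming the imports used by the function above (torch_sparse's `SparseTensor`, `matmul`, and `coalesce`) are already in scope, and standing in for the repo's own `new` copy helper with `Data.clone()`:

# Hypothetical usage sketch (path graph 0-1-2-3 with undirected edges).
import torch
from torch_geometric.data import Data

new = lambda d: d.clone()   # illustrative stand-in for the source's `new` helper

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
data = Data(edge_index=edge_index, x=torch.randn(4, 8), num_nodes=4)

out = make_multihop_edges(data, k=2)
# out.multihop_edge_index additionally contains (0,2), (2,0), (1,3), (3,1) with
# out.distance == 2, plus distance 1 for the original edges and 0 for self pairs.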
Example #18
0
 def message_and_aggregate(self, adj_t, x):
     if self.aggr_fun in ["var", "std"]:
         # TODO
         raise NotImplementedError
     return matmul(adj_t, x, reduce=self.aggr)
Example #19
 def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
     return matmul(adj_t, x, reduce=self.aggr)
Example #20
 def message_and_aggregate(self, adj_t, x):
     adj_t = adj_t.set_value(None, layout=None)
     return matmul(adj_t, x[0], reduce=self.aggr)
Example #21
 def message_and_aggregate(self, adj_t: SparseTensor,
                           x: OptPairTensor) -> Tensor:
     adj_t = adj_t.set_value(None, layout=None)
     return matmul(adj_t, x[0], reduce=self.aggr)
Example #22
 def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
     adj_t = adj_t.set_value(None)
     return matmul(adj_t, x, reduce=self.aggr)
Example #23
def csr_dmm_gpu(s, d):
    with torch.no_grad():
        y = torch_sparse.matmul(s, d, reduce="sum")
        torch.cuda.synchronize()
        return y
 def forward(self, x, adj_t, adj_t2):
     x1 = matmul(adj_t, x)
     x2 = matmul(adj_t2, x)
     return torch.cat([x1, x2], dim=1)
Example #25
 def message_and_aggregate(self, adj_t, x):
     return matmul(adj_t, x, reduce=self.aggr)
Example #26
 def message_and_aggregate(self, adj_t, x):
     # print("message and aggregate")
     return matmul(adj_t, x, reduce=self.aggr)