Example #1
0
 def _to_csr(self):
     """Reorder the COO edge list into CSR order and build ``row_ptr``.

     ``coo2csr_index`` returns the CSR row pointer and ``reindex``, an index
     tensor that reorders the edges.  ``row``, ``col`` and ``weight`` are all
     gathered with the same ``reindex`` so per-edge data stays aligned with
     its edge.  If no weight is set, a weight of one per edge is created.

     Mutates ``self.row_ptr``, ``self.row``, ``self.col`` and ``self.weight``
     in place; returns ``None``.
     """
     self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
     self.col = self.col[reindex]
     self.row = self.row[reindex]
     if self.weight is None:
         # Allocate directly on the edges' device rather than building on the
         # CPU and copying with ``.to()``.  An all-ones tensor is unaffected by
         # reordering, so it needs no gather.
         self.weight = torch.ones(self.row.shape[0], device=self.row.device)
     else:
         self.weight = self.weight[reindex]
Example #2
0
 def _to_csr(self):
     """Reorder the COO edge list into CSR order and build ``row_ptr``.

     ``coo2csr_index`` returns the CSR row pointer and ``reindex``, an index
     tensor that reorders the edges.  ``row`` and ``col`` are gathered with it,
     and so is every non-``None`` attribute listed by ``__attr_keys__()``.  A
     missing ``"weight"`` is replaced by a weight of one per edge.

     NOTE(review): assumes ``__attr_keys__()`` does not include ``"row"`` or
     ``"col"`` (they would otherwise be permuted twice), and that ``self[key]``
     reads/writes the attribute named ``key`` -- confirm against the class.

     Mutates ``self`` in place; returns ``None``.
     """
     self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
     self.col = self.col[reindex]
     self.row = self.row[reindex]
     for key in self.__attr_keys__():
         value = self[key]
         if value is not None:
             # Per-edge data must follow the same permutation as row/col.
             self[key] = value[reindex]
         elif key == "weight":
             # Build the default weight on the edges' device directly.  A fresh
             # all-ones tensor of length len(reindex) is already in CSR order,
             # so gathering it with ``reindex`` would only copy it.
             self.weight = torch.ones(self.row.shape[0], device=self.row.device)
Example #3
0
 def add_remaining_self_loops(self):
     """Add self-loops to the edge list, then re-sort edges into CSR order.

     The bare name ``add_remaining_self_loops`` below resolves to the
     module-level helper, not to this method.  The helper returns the new
     ``(row, col)`` edge index and a per-edge weight tensor.  The edges are
     then reordered with the ``reindex`` from ``coo2csr_index``, and ``weight``
     is gathered with the same index so each weight stays attached to its
     edge.  ``attr`` is discarded.

     Mutates ``self`` in place; returns ``None``.
     """
     edge_index, self.weight = add_remaining_self_loops(
         (self.row, self.col), num_nodes=self.num_nodes)
     self.row, self.col = edge_index
     self.row_ptr, reindex = coo2csr_index(self.row,
                                           self.col,
                                           num_nodes=self.num_nodes)
     self.row = self.row[reindex]
     self.col = self.col[reindex]
     # The weight is aligned with the pre-sort edge order; permute it with
     # row/col or weights end up attached to the wrong edges.
     if self.weight is not None:
         self.weight = self.weight[reindex]
     self.attr = None
Example #4
0
    def padding_self_loops(self):
        """Append one self-loop ``(i, i)`` per node and re-sort into CSR order.

        A loop is appended for every ``i`` in ``range(self.num_nodes)``; no
        check is made for loops that already exist.  When present:

        - ``weight`` gets ``0.01`` for each new loop;
        - ``attr`` gets zeros for each new loop, shaped like one row of
          ``attr`` so multi-dimensional attributes are supported.

        After ``coo2csr_index`` computes ``row_ptr`` and the edge ``reindex``,
        ``row``, ``col``, ``weight`` and ``attr`` are all gathered with the same
        index so per-edge data stays aligned.

        Mutates ``self`` in place; returns ``None``.
        """
        device = self.row.device
        loops = torch.arange(self.num_nodes, device=device)
        self.row = torch.cat((self.row, loops))
        self.col = torch.cat((self.col, loops))

        if self.weight is not None:
            # new_full inherits dtype/device from the existing weights.
            loop_weight = self.weight.new_full((self.num_nodes,), 0.01)
            self.weight = torch.cat((self.weight, loop_weight))
        if self.attr is not None:
            # Match trailing dimensions so torch.cat works for 1-D and N-D attr.
            loop_attr = self.attr.new_zeros((self.num_nodes, *self.attr.shape[1:]))
            self.attr = torch.cat((self.attr, loop_attr))
        self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
        self.row = self.row[reindex]
        self.col = self.col[reindex]
        # The appended loops sit at the end before sorting; weight/attr must be
        # permuted with row/col or they no longer line up with their edges.
        if self.weight is not None:
            self.weight = self.weight[reindex]
        if self.attr is not None:
            self.attr = self.attr[reindex]
Example #5
0
 def add_remaining_self_loops(self):
     """Add self-loops to the edge list, then re-sort edges into CSR order.

     The bare name ``add_remaining_self_loops`` below resolves to the
     module-level helper, not to this method.

     - If ``attr`` is a 1-D tensor, it is passed to the helper as
       ``edge_weight`` with ``fill_value=0`` (presumably the value given to
       newly added loops -- confirm with the helper).  The helper's returned
       weights become the new ``attr``, and ``weight`` is reset to float ones.
     - Otherwise the helper is called with ``fill_value=1``, its returned
       weights become ``weight``, and ``attr`` is dropped.

     Edges are then reordered with the ``reindex`` from ``coo2csr_index``.
     ``weight`` and ``attr`` are gathered with the same index so they stay
     aligned with ``row``/``col``.

     Mutates ``self`` in place; returns ``None``.
     """
     if self.attr is not None and len(self.attr.shape) == 1:
         edge_index, weight_attr = add_remaining_self_loops(
             (self.row, self.col), edge_weight=self.attr, fill_value=0, num_nodes=self.num_nodes
         )
         self.row, self.col = edge_index
         self.attr = weight_attr
         self.weight = torch.ones_like(self.row, dtype=torch.float)
     else:
         edge_index, self.weight = add_remaining_self_loops(
             (self.row, self.col), fill_value=1, num_nodes=self.num_nodes
         )
         self.row, self.col = edge_index
         self.attr = None
     self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
     self.row = self.row[reindex]
     self.col = self.col[reindex]
     # weight/attr are aligned with the pre-sort edge order; permute them with
     # row/col so each value stays attached to its edge.
     if self.weight is not None:
         self.weight = self.weight[reindex]
     if self.attr is not None:
         self.attr = self.attr[reindex]