def edge_index_from_dict(graph_dict, num_nodes=None):
    """Build a symmetric ``edge_index`` tensor from an adjacency dict.

    Each ``key -> value`` entry contributes the edges ``(key, v)`` for every
    ``v`` in ``value``. Every edge is canonicalised to ``(larger, smaller)``
    so that ``(u, v)`` and ``(v, u)`` collapse onto the same pair, and self
    loops (``u == v``) are dropped. The remaining edges are sorted, emitted
    in both directions and passed through ``coalesce`` (presumably to sort
    and remove the duplicates the datasets may contain — TODO confirm
    against ``coalesce``).

    Args:
        graph_dict: Mapping from a node id to a sequence of neighbour node
            ids (assumed to be integer ids — verify against callers).
        num_nodes: Not used by this function; kept for interface
            compatibility.

    Returns:
        A ``torch.long`` tensor with two rows (row ids, col ids). When the
        dict yields no non-self-loop edges, an empty ``(2, 0)`` tensor is
        returned.
    """
    row_parts, col_parts = [], []
    for key, value in graph_dict.items():
        # Explicit int64 so an empty neighbour list cannot promote the
        # concatenated ids to float.
        neighbours = np.asarray(value, dtype=np.int64)
        row_parts.append(np.repeat(key, len(neighbours)))
        col_parts.append(neighbours)
    if not row_parts:
        return torch.empty((2, 0), dtype=torch.long)

    # Concatenate the list of 1-D arrays directly: wrapping it in np.array()
    # builds a ragged array when neighbour lists differ in length, which
    # NumPy >= 1.24 rejects with ValueError.
    row = np.concatenate(row_parts)
    col = np.concatenate(col_parts)
    edge_index = np.stack([row, col], axis=0)

    # Canonicalise to (larger id, smaller id). The strict comparisons also
    # discard self loops, since row == col matches neither mask.
    row_dom = edge_index[:, row > col]
    col_dom = edge_index[:, col > row][[1, 0]]
    edge_index = np.concatenate([row_dom, col_dom], axis=1)
    if edge_index.shape[1] == 0:
        return torch.empty((2, 0), dtype=torch.long)

    # Sort by (row, col); np.lexsort treats its last key as the primary one.
    order = np.lexsort((edge_index[1], edge_index[0]))
    edge_index = torch.tensor(edge_index[:, order], dtype=torch.long)

    # Emit both directions, then coalesce to drop duplicated edges.
    sym_row = torch.cat([edge_index[0], edge_index[1]])
    sym_col = torch.cat([edge_index[1], edge_index[0]])
    sym_row, sym_col, _ = coalesce(sym_row, sym_col)
    return torch.stack([sym_row, sym_col])
def forward(self, A):
    """Combine the input adjacencies into one weighted adjacency per channel.

    ``self.weight`` is softmax-normalised along dim 1. Row ``i`` of the
    normalised weights scales the edge values of each input adjacency for
    output channel ``i``. For every channel, all input edge lists are
    concatenated together with their scaled values, then merged with
    ``coalesce``.

    Args:
        A: Sequence of ``(edge_index, edge_value)`` pairs, one per input
            adjacency. Presumably ``edge_index`` is ``(2, E_j)``,
            ``edge_value`` is ``(E_j,)`` and ``len(A)`` equals
            ``self.weight.shape[1]`` — TODO confirm against callers.

    Returns:
        A list with one ``(index, value)`` pair per row of ``self.weight``,
        where ``index`` is ``torch.stack([row, col])`` of the coalesced edges.

    Raises:
        ValueError: If ``A`` contains no adjacencies.
    """
    # Materialise once so a one-shot iterable is not exhausted after the
    # first channel.
    adjacencies = list(A)
    if not adjacencies:
        raise ValueError("forward() requires at least one (edge_index, edge_value) pair")

    weights = F.softmax(self.weight, dim=1)
    edge_indices = [edge_index for edge_index, _ in adjacencies]
    results = []
    for channel_weights in weights:
        # One torch.cat per channel instead of repeated pairwise concatenation.
        # A fresh index tensor is still built per channel, as before.
        row, col = torch.cat(edge_indices, dim=1).detach()
        # Indexing (not zip) keeps the original IndexError when A has more
        # entries than the weight matrix has columns.
        value = torch.cat(
            [edge_value * channel_weights[j] for j, (_, edge_value) in enumerate(adjacencies)]
        )
        row, col, value = coalesce(row, col, value)
        results.append((torch.stack([row, col]), value))
    return results