def process(self):
    """Build one ``Data`` graph per raw ``*.edgelist`` file and save them.

    The position of each entry in ``self.raw_paths`` becomes the label ``y``
    of every graph found beneath it. ``self.pre_filter`` and
    ``self.pre_transform`` are applied (in that order) when they are set,
    and the collated result is written to ``self.processed_paths[0]``.
    """
    graphs = []
    for label, class_root in enumerate(self.raw_paths):
        # Descend into the first entry reported by ``os.listdir``.
        graph_dir = osp.join(class_root, os.listdir(class_root)[0])

        for path in glob.glob(osp.join(graph_dir, '*.edgelist')):
            with open(path, 'r') as handle:
                lines = handle.read().split('\n')

            # The first five lines and the final split element are skipped.
            pairs = [[int(tok) for tok in line.split()]
                     for line in lines[5:-1]]
            edges = torch.tensor(pairs).t().contiguous()
            # Remove isolated nodes, including those with only a self-loop
            edges = remove_isolated_nodes(edges)[0]
            node_count = int(edges.max()) + 1
            graphs.append(
                Data(edge_index=edges, y=label, num_nodes=node_count))

    if self.pre_filter is not None:
        graphs = list(filter(self.pre_filter, graphs))

    if self.pre_transform is not None:
        graphs = list(map(self.pre_transform, graphs))

    torch.save(self.collate(graphs), self.processed_paths[0])
def test_remove_isolated_nodes():
    """Check ``remove_isolated_nodes`` on graphs with and without isolation.

    Covers three situations: no node is removed, an extra node is declared
    through ``num_nodes``, and a node that only carries a self-loop is
    dropped together with its edge attribute.
    """
    # Nodes 0 and 1 are connected; the self-loop on 0 is kept as-is.
    edges = torch.tensor([[0, 1, 0], [1, 0, 0]])
    unchanged = [[0, 1, 0], [1, 0, 0]]

    kept_edges, _, node_mask = remove_isolated_nodes(edges)
    assert kept_edges.tolist() == unchanged
    assert node_mask.tolist() == [1, 1]

    # Declaring a third node only extends the mask with a dropped entry.
    kept_edges, _, node_mask = remove_isolated_nodes(edges, num_nodes=3)
    assert kept_edges.tolist() == unchanged
    assert node_mask.tolist() == [1, 1, 0]

    # Node 1 only has a self-loop: it is removed and node 2 becomes node 1.
    edges = torch.tensor([[0, 2, 1, 0, 2], [2, 0, 1, 0, 2]])
    weights = torch.tensor([1, 2, 3, 4, 5])

    kept_edges, kept_weights, node_mask = remove_isolated_nodes(edges, weights)
    assert kept_edges.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]]
    assert kept_weights.tolist() == [1, 2, 4, 5]
    assert node_mask.tolist() == [1, 0, 1]
def __call__(self, data):
    """Drop isolated nodes from ``data`` and filter its node-level tensors.

    ``data.edge_index`` and ``data.edge_attr`` are replaced by the outputs
    of ``remove_isolated_nodes``. Every other tensor attribute whose first
    dimension equals the original ``data.num_nodes`` is indexed with the
    returned node mask. Attributes whose key contains ``"edge"`` are left
    untouched.

    Args:
        data: Graph object exposing ``num_nodes``, ``edge_index`` and
            ``edge_attr``, iterable as ``(key, item)`` pairs and supporting
            item assignment by key.

    Returns:
        The same ``data`` object, modified in place.
    """
    num_nodes = data.num_nodes
    out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes)
    data.edge_index, data.edge_attr, mask = out

    for key, item in data:
        if "edge" in key or not torch.is_tensor(item):
            continue
        # ``size(0)`` raises IndexError on 0-dim tensors, so scalar tensor
        # attributes must be ruled out before comparing the first dimension.
        if item.dim() > 0 and item.size(0) == num_nodes:
            data[key] = item[mask]

    return data
def __call__(self, data):
    """Drop isolated nodes from ``data`` and filter its node-level tensors.

    ``data.edge_index`` and ``data.edge_attr`` are replaced by the outputs
    of ``remove_isolated_nodes``. Every other tensor attribute whose first
    dimension equals the original ``data.num_nodes`` is indexed with the
    returned node mask. Attributes whose key contains ``'edge'`` are
    skipped.

    Args:
        data: Graph object exposing ``num_nodes``, ``edge_index`` and
            ``edge_attr``, iterable as ``(key, item)`` pairs and supporting
            item assignment by key.

    Returns:
        The same ``data`` object, modified in place.
    """
    num_nodes = data.num_nodes
    out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes)
    data.edge_index, data.edge_attr, mask = out

    for key, item in data:
        # A plain substring test matches exactly what the literal pattern
        # 'edge' matched, without going through the regex engine.
        if 'edge' in key:
            continue
        if not torch.is_tensor(item):
            continue
        # ``size(0)`` raises IndexError on 0-dim tensors, so scalar tensor
        # attributes must be ruled out before comparing the first dimension.
        if item.dim() > 0 and item.size(0) == num_nodes:
            data[key] = item[mask]

    return data
def __call__(cls, state_pixels):
    """Convert raw pixels to a graph.

    Every pixel starts out as one node; nodes left without any edge by
    ``cls.linker`` are removed before the graph is assembled.

    Args:
        - state_pixels: (W, H, 3) np.uint8

    Hooks read from ``cls`` (not passed as arguments):
        - embedding: function taking the (W, H, 3) state and returning a
          (W, H, N) Tensor, N >= 4, always including the embedding_minimal
        - linker: function taking the (W, H, 3) state and returning the
          edge index
        - get_node_neighbors_mask / get_edge_attr: build the ``mask`` and
          ``edge_attr`` entries of the result

    Return:
        - torch_geometric Data instance
    """
    # Pixel state as a tensor
    state = torch.tensor(state_pixels, dtype=torch.uint8)

    # One candidate node per pixel, with its grid coordinates (num_nodes, 2)
    grid_shape = state.shape[:2]
    num_nodes = state.size(0) * state.size(1)
    flat_ids = np.arange(num_nodes)
    coords = np.stack(np.unravel_index(flat_ids, grid_shape), axis=1)

    # Node features flattened to (num_nodes, n_features)
    features = cls.embedding(state).float().reshape(num_nodes, -1)

    # Connectivity, then drop every node that ended up without an edge
    raw_edges = cls.linker(state)
    edge_index, _, keep = remove_isolated_nodes(raw_edges,
                                                num_nodes=num_nodes)
    features = features[keep]
    positions = torch.tensor(coords[keep], dtype=torch.long)

    # Player node index, expected shape (1,)
    player_idx = utils.find_player_idx(features)
    # Action mask, expected (num_nodes, 1) bool
    action_mask = cls.get_node_neighbors_mask(player_idx, edge_index,
                                              features)

    # Edge features: None or (num_edges_features)
    edge_attr = cls.get_edge_attr(edge_index, features, positions)

    return Data(
        x=features,
        pos=positions,
        edge_index=edge_index.contiguous(),
        edge_attr=edge_attr,
        mask=action_mask,
        player_idx=player_idx,
    )
def save_attention(edge_index, att, labels, genes):
    """Write the attention graph to an edge CSV and a node CSV.

    Self-loops and isolated nodes are removed first. The edge file holds the
    two endpoint ids plus the per-edge attention maximum; the node file
    holds, for every surviving node, its id (as used in the edge file), its
    gene and its label.

    Both output paths are built from ``organism``, which is neither a
    parameter nor a local here and must be available from the enclosing
    scope.

    Args:
        edge_index: Edge array exposing ``.T`` and ``.shape``; presumably a
            NumPy array of shape (num_edges, 2) -- TODO confirm with caller.
        att: Attention array reduced with ``att.max(1)``; presumably
            (num_edges, num_heads) NumPy -- TODO confirm with caller.
        labels: Per-node labels, indexable by an integer array.
        genes: Per-node gene identifiers, indexable by an integer array.
    """
    # ------------------- Saving attention network -----------------------------------
    # Keep the maximum along axis 1 as a single attention column.
    att = att.max(1).reshape((-1, 1))
    print(edge_index.shape, labels.shape, att.shape)

    edge_index, att = remove_self_loops(torch.tensor(edge_index.T),
                                        torch.tensor(att))

    # remove_isolated_nodes renumbers the surviving nodes to 0..k-1, so the
    # ids found in the returned edge_index no longer address rows of
    # ``labels``/``genes``. The returned mask is the only link back to the
    # original rows, hence it must cover at least ``len(labels)`` nodes.
    num_labelled = len(labels)
    num_referenced = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    edge_index, att, kept = remove_isolated_nodes(
        edge_index, att, num_nodes=max(num_labelled, num_referenced))

    edge_index = edge_index.numpy().T
    att = att.numpy()
    print(edge_index.shape, labels.shape, att.shape)

    # original_ids[new_id] is the pre-removal index of node ``new_id``.
    original_ids = np.flatnonzero(kept.numpy())
    # Only nodes that actually have a label/gene row can be exported.
    has_row = original_ids < num_labelled
    nodes_idx = np.arange(len(original_ids))[has_row]
    labels = labels[original_ids[has_row]]
    genes = genes[original_ids[has_row]]

    meta_edges = pd.DataFrame(np.concatenate([edge_index.astype(int), att], 1))
    meta_edges.to_csv(f'../data/essential_genes/gat_attention/{organism}_edges.csv', index=False)

    meta_nodes = pd.DataFrame(np.stack([nodes_idx, genes, labels], 1))
    path = f'../data/essential_genes/gat_attention/{organism}_nodes.csv'
    meta_nodes.to_csv(path, index=False)

    print(meta_edges.shape, meta_nodes.shape)
    print('Saved edges and nodes to ', path)