def process(self):
        """Build one graph per raw ``*.edgelist`` file and save the collated dataset.

        Each entry of ``self.raw_paths`` is one class: its position in
        ``self.raw_paths`` becomes the graph label ``y``. Inside each raw path
        the first sub-entry (in sorted order) is used as the directory holding
        the ``.edgelist`` files (presumably a single extracted folder — TODO
        confirm). The optional ``pre_filter`` / ``pre_transform`` hooks are
        applied before the collated result is saved to
        ``self.processed_paths[0]``.
        """
        data_list = []

        for y, raw_path in enumerate(self.raw_paths):
            # os.listdir / glob return entries in arbitrary order; sort them so
            # the chosen directory and the graph order are reproducible.
            raw_path = osp.join(raw_path, sorted(os.listdir(raw_path))[0])
            filenames = sorted(glob.glob(osp.join(raw_path, '*.edgelist')))

            for filename in filenames:
                with open(filename, 'r') as f:
                    lines = f.read().splitlines()[5:]
                # The first 5 lines are skipped (presumably a header). Blank
                # lines are filtered out instead of always dropping the last
                # line, which lost a real edge when the file had no trailing
                # newline.
                edges = [line for line in lines if line.strip()]
                edge_index = [[int(s) for s in edge.split()] for edge in edges]
                edge_index = torch.tensor(edge_index).t().contiguous()
                # Remove isolated nodes, including those with only a self-loop
                edge_index = remove_isolated_nodes(edge_index)[0]
                num_nodes = int(edge_index.max()) + 1
                data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes)
                data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])
Beispiel #2
0
def test_remove_isolated_nodes():
    """Check ``remove_isolated_nodes`` with and without ``num_nodes``/``edge_attr``."""
    # Two connected nodes (node 0 also has a self-loop). With an inferred node
    # count nothing is isolated; declaring a third, unused node marks it
    # isolated while leaving the edges untouched.
    edges = torch.tensor([[0, 1, 0], [1, 0, 0]])
    cases = (
        ({}, [1, 1]),
        ({'num_nodes': 3}, [1, 1, 0]),
    )
    for extra_kwargs, expected_keep in cases:
        new_edges, _, keep = remove_isolated_nodes(edges, **extra_kwargs)
        assert new_edges.tolist() == [[0, 1, 0], [1, 0, 0]]
        assert keep.tolist() == expected_keep

    # Node 1 only has a self-loop, so it is dropped together with that loop
    # (and its attribute 3); node 2 is relabelled to 1.
    edges = torch.tensor([[0, 2, 1, 0, 2], [2, 0, 1, 0, 2]])
    attrs = torch.tensor([1, 2, 3, 4, 5])
    new_edges, new_attrs, keep = remove_isolated_nodes(edges, attrs)
    assert new_edges.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]]
    assert new_attrs.tolist() == [1, 2, 4, 5]
    assert keep.tolist() == [1, 0, 1]
Beispiel #3
0
    def __call__(self, data):
        """Remove isolated nodes from ``data`` and filter its node-level tensors.

        ``remove_isolated_nodes`` drops the isolated nodes and relabels
        ``edge_index`` (``edge_attr`` is passed through it as well). Every
        tensor attribute whose first dimension equals the original node count
        and whose key does not contain ``"edge"`` is then filtered with the
        returned node mask.

        Args:
            data: graph object exposing ``num_nodes``, ``edge_index``,
                ``edge_attr`` and ``(key, value)`` iteration.

        Returns:
            The same ``data`` object, modified in place.
        """
        num_nodes = data.num_nodes
        out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes)
        data.edge_index, data.edge_attr, mask = out

        for key, item in data:
            # 0-d tensors (e.g. a scalar graph label) have no size(0); skip them
            # instead of raising IndexError.
            if (torch.is_tensor(item) and item.dim() > 0
                    and item.size(0) == num_nodes and "edge" not in key):
                data[key] = item[mask]

        # Keep an explicitly stored node count in sync with the filtered
        # tensors; otherwise data.num_nodes would report the old count.
        data.num_nodes = int(mask.sum())

        return data
Beispiel #4
0
    def __call__(self, data):
        """Remove isolated nodes from ``data`` and filter its node-level tensors.

        ``remove_isolated_nodes`` drops the isolated nodes and relabels
        ``edge_index`` (``edge_attr`` is passed through it as well). Attributes
        whose key contains ``"edge"`` are left alone; every other tensor whose
        first dimension equals the original node count is filtered with the
        returned node mask.

        Args:
            data: graph object exposing ``num_nodes``, ``edge_index``,
                ``edge_attr`` and ``(key, value)`` iteration.

        Returns:
            The same ``data`` object, modified in place.
        """
        num_nodes = data.num_nodes
        out = remove_isolated_nodes(data.edge_index, data.edge_attr, num_nodes)
        data.edge_index, data.edge_attr, mask = out

        for key, item in data:
            # A plain substring test is equivalent to re.search('edge', key).
            if 'edge' in key:
                continue
            # 0-d tensors (e.g. a scalar graph label) have no size(0).
            if not torch.is_tensor(item) or item.dim() == 0:
                continue
            if item.size(0) == num_nodes:
                data[key] = item[mask]

        # Keep an explicitly stored node count in sync with the filtered
        # tensors; otherwise data.num_nodes would report the old count.
        data.num_nodes = int(mask.sum())

        return data
    def __call__(cls, state_pixels):
        """Convert raw pixels to a graph.

        Args:
            - state_pixels: (W, H, 3) np.uint8

        Hooks read from ``cls`` (attributes, not arguments):
            - embedding: function taking a (W, H, 3) and returning (W, H, N)
              Tensor, N >= 4 and always including the embedding_minimal
            - linker: function taking a (W, H, 3) state and returning edges_index
            - get_node_neighbors_mask / get_edge_attr: build the action mask
              and the edge features

        Return:
            - torch_geometric Data instance with x, pos, edge_index, edge_attr,
              mask and player_idx. Pixels without any incident edge are
              removed; edge_index is relabelled to the remaining nodes.
        """
        # Pixel state
        state = torch.tensor(state_pixels, dtype=torch.uint8)

        # One node per pixel; original (row, col) coordinates (num_nodes, 2)
        num_nodes = state.size(0) * state.size(1)
        pos = np.stack(np.unravel_index(np.arange(num_nodes), state.shape[:2]),
                       axis=1)

        # Nodes embedding (num_nodes, n_features)
        x = cls.embedding(state).float().reshape(num_nodes, -1)

        # Edges
        edges_index = cls.linker(state)

        # Remove isolated nodes (edges are relabelled to the surviving nodes)
        edges_index, _, nodes_mask = remove_isolated_nodes(edges_index,
                                                           num_nodes=num_nodes)
        x = x[nodes_mask]
        # pos is a NumPy array: index it with an explicit NumPy mask instead of
        # relying on NumPy's implicit tensor conversion, which fails for CUDA.
        pos = torch.tensor(pos[nodes_mask.cpu().numpy()], dtype=torch.long)

        # Action mask
        player_idx = utils.find_player_idx(x)  # (1,)
        # (num_nodes, 1) bool
        mask = cls.get_node_neighbors_mask(player_idx, edges_index, x)

        # Edge features, None or (num_edges_features)
        edge_attr = cls.get_edge_attr(edges_index, x, pos)

        # Assemble the graph (the action mask is stored, not applied)
        graph = Data(
            x=x,
            pos=pos,
            edge_index=edges_index.contiguous(),
            edge_attr=edge_attr,
            mask=mask,
            player_idx=player_idx,
        )

        return graph
Beispiel #6
0
def save_attention(edge_index, att, labels, genes):
    """Save an attention network as an edges CSV and a nodes CSV.

    Args:
        edge_index: NumPy array of edges; it is transposed before conversion
            to a tensor, so presumably shaped (num_edges, 2) — TODO confirm.
        att: NumPy array of attention weights per edge; only the max over
            axis 1 is kept.
        labels: per-node labels indexed by node id (must support ``.shape``
            and boolean indexing).
        genes: per-node gene names indexed by node id.

    Self-loops are removed and the non-isolated nodes are relabelled to
    consecutive ids. The function writes ``{organism}_edges.csv`` (source id,
    target id, attention) and ``{organism}_nodes.csv`` (id, gene, label) to
    ``../data/essential_genes/gat_attention/``. NOTE(review): ``organism`` is
    not a parameter; it must exist in the enclosing (module) scope.
    """
    # ------------------- Saving attention network ----------------------------------- #
    att = att.max(1).reshape((-1, 1))
    print(edge_index.shape, labels.shape, att.shape)

    edge_index, att = remove_self_loops(torch.tensor(edge_index.T),
                                        torch.tensor(att))

    # remove_isolated_nodes relabels the surviving nodes to 0..N'-1, in order
    # of their original id. labels/genes must therefore be selected with the
    # returned mask. Taking the first N' entries pairs edges with the wrong
    # genes. num_nodes is passed so that the mask covers every labelled node.
    num_nodes = len(labels)
    if edge_index.numel() > 0:
        num_nodes = max(num_nodes, int(edge_index.max()) + 1)
    edge_index, att, mask = remove_isolated_nodes(edge_index, att,
                                                  num_nodes=num_nodes)
    keep = mask[:len(labels)].cpu().numpy()

    edge_index = edge_index.numpy().T
    att = att.numpy()
    print(edge_index.shape, labels.shape, att.shape)

    labels = labels[keep]
    genes = genes[keep]
    # Labelled nodes have the smallest original ids. After relabelling, the
    # kept ones therefore receive ids 0..len(labels)-1, matching the edges file.
    nodes_idx = np.arange(len(labels))

    out_dir = '../data/essential_genes/gat_attention'
    meta_edges = pd.DataFrame(np.concatenate([edge_index.astype(int), att], 1))
    meta_edges.to_csv(f'{out_dir}/{organism}_edges.csv', index=False)

    meta_nodes = pd.DataFrame(np.stack([nodes_idx, genes, labels], 1))
    path = f'{out_dir}/{organism}_nodes.csv'
    meta_nodes.to_csv(path, index=False)
    print(meta_edges.shape, meta_nodes.shape)
    print('Saved edges and nodes to ', path)