Example 1
def test_graph_store():
    graph_store = MyGraphStore()
    edge_index = torch.LongTensor([[0, 1], [1, 2]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    def assert_equal_tensor_tuple(expected, actual):
        assert len(expected) == len(actual)
        for i in range(len(expected)):
            assert torch.equal(expected[i], actual[i])

    # We put all three supported layouts: COO, CSR, and CSC, and read them
    # back to confirm that `GraphStore` round-trips as intended.
    coo = adj.coo()[:-1]
    csr = adj.csr()[:-1]
    csc = adj.csc()[-2::-1]  # (row, colptr)

    # Put:
    graph_store['edge', EdgeLayout.COO] = coo
    graph_store['edge', 'csr'] = csr
    graph_store['edge', 'csc'] = csc

    # Get:
    assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
    assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
    assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])

    # Get attrs:
    edge_attrs = graph_store.get_all_edge_attrs()
    assert len(edge_attrs) == 3

    with pytest.raises(KeyError):
        _ = graph_store['edge_2', 'coo']
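The slicing above leans on the tuple ordering of torch_sparse's layout accessors; here is a minimal standalone sketch (assuming torch_sparse is installed) of what each call returns:

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1], [1, 2]])
adj = SparseTensor(row=edge_index[0], col=edge_index[1])

row, col, value = adj.coo()       # COO: (row, col, value); value is None here
rowptr, col, value = adj.csr()    # CSR: (rowptr, col, value)
colptr, row, value = adj.csc()    # CSC: (colptr, row, value)
# Hence `[:-1]` drops the value entry and `[-2::-1]` yields (row, colptr).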
Example 2
    def forward(self, x: Tensor, M: SparseTensor, batch: OptTensor = None):
        """"""
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        row, col, edge_weight = M.coo()

        score1 = (x * self.p).sum(dim=-1)
        score2 = scatter_add(edge_weight, col, dim=0, dim_size=x.size(0))
        score = self.beta[0] * score1 + self.beta[1] * score2

        if self.min_score is None:
            score = self.nonlinearity(score)
        else:
            score = softmax(score, batch)

        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        edge_index = torch.stack([col, row], dim=0)
        edge_index, edge_attr = filter_adj(edge_index, edge_weight, perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch[perm], perm, score[perm]
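A toy illustration of the top-k selection step above (a sketch assuming the `topk` helper from `torch_geometric.nn.pool.topk_pool`, the same helper the snippet calls):

import torch
from torch_geometric.nn.pool.topk_pool import topk

score = torch.tensor([0.1, 0.9, 0.5, 0.7])
batch = torch.zeros(4, dtype=torch.long)     # a single graph
perm = topk(score, ratio=0.5, batch=batch)   # keep the 2 highest-scoring nodes
# perm -> tensor([1, 3]), the indices of scores 0.9 and 0.7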
Example 3
    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """"""
        N = x.size(0)

        edge_index, edge_weight = add_remaining_self_loops(edge_index,
                                                           edge_weight,
                                                           fill_value=1,
                                                           num_nodes=N)

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        x_pool = x
        if self.GNN is not None:
            x_pool = self.gnn_intra_cluster(x=x,
                                            edge_index=edge_index,
                                            edge_weight=edge_weight)

        x_pool_j = x_pool[edge_index[0]]
        x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
        x_q = self.lin(x_q)[edge_index[1]]

        score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[1], num_nodes=N)

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout, training=self.training)

        v_j = x[edge_index[0]] * score.view(-1, 1)
        x = scatter(v_j, edge_index[1], dim=0, reduce='add')

        # Cluster selection.
        fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
        perm = topk(fitness, self.ratio, batch)
        x = x[perm] * fitness[perm].view(-1, 1)
        batch = batch[perm]

        # Graph coarsening.
        row, col = edge_index
        A = SparseTensor(row=row,
                         col=col,
                         value=edge_weight,
                         sparse_sizes=(N, N))
        S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
        S = S[:, perm]

        A = S.t() @ A @ S

        if self.add_self_loops:
            A = A.fill_diag(1.)
        else:
            A = A.remove_diag()

        row, col, edge_weight = A.coo()
        edge_index = torch.stack([row, col], dim=0)

        return x, edge_index, edge_weight, batch, perm
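The coarsening step computes A' = Sᵀ A S, where S holds the soft cluster assignments restricted to the kept clusters. A minimal sketch of the sparse-sparse product with torch_sparse (toy shapes, illustrative values):

import torch
from torch_sparse import SparseTensor

A = SparseTensor.from_dense(torch.eye(4))      # 4-node adjacency
S = SparseTensor.from_dense(torch.rand(4, 2))  # assignment to 2 clusters
A_coarse = S.t() @ A @ S                       # (2, 2) coarsened adjacency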
Example 4
def graclus_coarsen(A: SparseTensor, level: int):
    row, col, wgt = A.coo()
    coarsen_cluster = []
    for i in range(level):
        cluster = graclus_cluster(row, col, wgt)
        _, cluster = cluster.unique(return_inverse=True)
        (row, col), wgt = pool_edge(cluster, torch.stack([row, col]), wgt)
        coarsen_cluster.append(cluster.cpu().numpy())
    return row, col, wgt, coarsen_cluster
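Hypothetical usage on a small undirected triangle (a sketch; it assumes torch_cluster's `graclus_cluster` and PyG's `pool_edge` are imported, as the function above requires):

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 1, 2, 2, 0])
col = torch.tensor([1, 0, 2, 1, 0, 2])
wgt = torch.ones(6)
A = SparseTensor(row=row, col=col, value=wgt, sparse_sizes=(3, 3))
row, col, wgt, clusters = graclus_coarsen(A, level=2)
# `clusters` holds one relabelled cluster assignment per coarsening level.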
Example 5
    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        """"""
        if x.dim() > 1:
            assert (x.sum(dim=-1) == 1).sum() == x.size(0)
            x = x.argmax(dim=-1)  # one-hot -> integer.
        assert x.dtype == torch.long

        adj_t = edge_index
        if not isinstance(adj_t, SparseTensor):
            adj_t = SparseTensor(row=edge_index[1],
                                 col=edge_index[0],
                                 sparse_sizes=(x.size(0), x.size(0)))

        out = []
        _, col, _ = adj_t.coo()
        deg = adj_t.storage.rowcount().tolist()
        for node, neighbors in zip(x.tolist(), x[col].split(deg)):
            idx = hash(tuple([node] + neighbors.sort()[0].tolist()))
            if idx not in self.hashmap:
                self.hashmap[idx] = len(self.hashmap)
            out.append(self.hashmap[idx])

        return torch.tensor(out, device=x.device)
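The refinement above is one step of 1-WL colour refinement: a node's new colour is a hash of its own colour together with the sorted multiset of its neighbours' colours. A self-contained pure-Python sketch of a single step (hypothetical helper, not part of PyG):

def wl_step(labels, neighbors, hashmap):
    out = []
    for v, nbrs in enumerate(neighbors):
        # hash (own label, sorted neighbour labels) into a compact id
        key = hash(tuple([labels[v]] + sorted(labels[u] for u in nbrs)))
        out.append(hashmap.setdefault(key, len(hashmap)))
    return out

labels = [0, 0, 1]
neighbors = [[1], [0, 2], [1]]
print(wl_step(labels, neighbors, {}))  # [0, 1, 2]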
Example 6
    def __dropout_adj__(self, sparse_adj: SparseTensor,
                        dropout_adj_prob: float):
        # number of nodes
        N = sparse_adj.size(0)
        # SparseTensor -> edge_index (COO) representation
        row, col, edge_attr = sparse_adj.coo()
        edge_index = torch.stack([row, col], dim=0)
        # randomly drop edges from the adjacency matrix (regularization)
        edge_index, edge_attr = dropout_adj(edge_index,
                                            edge_attr=edge_attr,
                                            p=dropout_adj_prob,
                                            force_undirected=True,
                                            training=self.training)
        # because dropout removes self-loops (due to force_undirected=True),
        # make sure to add them back again
        edge_index, edge_attr = add_remaining_self_loops(edge_index,
                                                         edge_weight=edge_attr,
                                                         fill_value=0.00,
                                                         num_nodes=N)
        # edge_index (COO) representation -> SparseTensor
        sparse_adj = SparseTensor.from_edge_index(edge_index,
                                                  edge_attr=edge_attr,
                                                  sparse_sizes=(N, N))

        return sparse_adj
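The helper crosses between the two edge representations at its boundaries; a minimal round-trip sketch of that conversion in isolation:

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1], [1, 2]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))
row, col, value = adj.coo()
edge_index = torch.stack([row, col], dim=0)  # back to COO edge_index form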
Example 7
class GraphSAINTSampler(object):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch to load.
        num_steps (int, optional): The number of iterations.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        num_workers (int, optional): How many subprocesses to use for data
            sampling.
            :obj:`0` means that the data will be sampled in the main process.
            (default: :obj:`0`)
    """

    def __init__(self, data, batch_size, num_steps=1, sample_coverage=50,
                 save_dir=None, num_workers=0):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                                value=data.edge_attr, sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.batch_size = batch_size
        self.num_steps = num_steps
        self.sample_coverage = sample_coverage
        self.num_workers = num_workers
        self.__count__ = 0

        if self.num_workers > 0:
            self.__sample_queue__ = Queue()
            self.__sample_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_sample__,
                                 args=(self.__sample_queue__, ))
                worker.daemon = True
                worker.start()
                self.__sample_workers__.append(worker)

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:
                torch.save((self.node_norm, self.edge_norm), path)

        if self.num_workers > 0:
            self.__data_queue__ = Queue()
            self.__data_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_data__,
                                 args=(self.__data_queue__, ))
                worker.daemon = True
                worker.start()
                self.__data_workers__.append(worker)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self, num_examples):
        raise NotImplementedError

    def __sample__(self, num_examples):
        node_samples = self.__sample_nodes__(num_examples)

        samples = []
        for node_idx in node_samples:
            node_idx = node_idx.unique()
            adj, edge_idx = self.adj.saint_subgraph(node_idx)
            samples.append((node_idx, edge_idx, adj))
        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        num_samples = total_sampled_nodes = 0
        pbar = tqdm(total=self.N * self.sample_coverage)
        pbar.set_description('Compute GraphSAINT normalization')
        while total_sampled_nodes < self.N * self.sample_coverage:
            num_sampled_nodes = 0
            if self.num_workers > 0:
                for _ in range(200):
                    node_idx, edge_idx, _ = self.__sample_queue__.get()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            else:
                samples = self.__sample__(200)
                for node_idx, edge_idx, _ in samples:
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            pbar.update(num_sampled_nodes)
            total_sampled_nodes += num_sampled_nodes
            num_samples += 200
        pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        node_idx, edge_idx, adj = sample

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            if item.size(0) == self.N:
                data[key] = item[node_idx]
            elif item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        data.node_norm = self.node_norm[node_idx]
        data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __put_sample__(self, queue):
        while True:
            sample = self.__sample__(1)[0]
            queue.put(sample)

    def __put_data__(self, queue):
        while True:
            sample = self.__sample_queue__.get()
            data = self.__get_data_from_sample__(sample)
            queue.put(data)

    def __next__(self):
        if self.__count__ < len(self):
            self.__count__ += 1
            if self.num_workers > 0:
                data = self.__data_queue__.get()
            else:
                sample = self.__sample__(1)[0]
                data = self.__get_data_from_sample__(sample)
            return data
        else:
            raise StopIteration

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        self.__count__ = 0
        return self
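A hedged sketch of how the sampled batches and their norms are typically consumed, following the loss weighting from the GraphSAINT paper. Here `sampler` stands for a concrete subclass of the class above and `model` for any PyG model; both names are illustrative:

import torch.nn.functional as F

for data in sampler:
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out, data.y, reduction='none')
    # weight per-node losses by node_norm for an unbiased full-graph estimate
    loss = (loss * data.node_norm)[data.train_mask].sum()
    loss.backward()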
Example 8
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or
            :obj:`num_workers`.
    """
    def __init__(self,
                 data,
                 batch_size: int,
                 num_steps: int = 1,
                 sample_coverage: int = 0,
                 save_dir: Optional[str] = None,
                 log: bool = True,
                 **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1],
                                value=torch.arange(
                                    self.E, device=data.edge_index.device),
                                sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler, self).__init__(self,
                                                batch_size=1,
                                                collate_fn=self.__collate__,
                                                **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self,
                                             batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
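Why the values are `torch.arange(E)`: after `saint_subgraph`, the values surviving in the sub-adjacency are exactly the original edge ids, so edge-level attributes can be gathered through them (as `__collate__` does above). A minimal sketch:

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                   value=torch.arange(3), sparse_sizes=(3, 3))
sub_adj, _ = adj.saint_subgraph(torch.tensor([0, 1]))
_, _, e_id = sub_adj.coo()  # original ids of the edges kept in the subgraph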
Example 9
    def sample_adj(self, adj: SparseTensor) -> SparseTensor:
        row, col, _ = adj.coo()
        deg = degree(row, num_nodes=adj.size(0))
        prob = (self.max_sample * (1. / deg))[row]
        mask = torch.rand_like(prob) < prob
        return adj.masked_select_nnz(mask, layout='coo')
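Each out-edge of a node with degree d survives with probability min(1, max_sample / d), so the expected number of kept out-edges per node is capped near `max_sample`. A toy check (with `max_sample` assumed to be 2):

import torch
from torch_geometric.utils import degree

row = torch.tensor([0, 0, 0, 0])   # node 0 has out-degree 4
deg = degree(row, num_nodes=1)     # tensor([4.])
prob = (2 * (1. / deg))[row]       # tensor([0.5, 0.5, 0.5, 0.5])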
Example 10
class MySAINTSampler(object):
    r"""A new random-walk sampler for GraphSAINT that samples initial nodes
    by iterating over node permutations. The benefit is that we can leverage
    this sampler for subgraph-based inference.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The number of initial nodes (or walks) to sample
            per batch.
        sample_type (string, optional): The sampling scheme, either
            :obj:`"node"` or :obj:`"random_walk"`.
            (default: :obj:`"random_walk"`)
        walk_length (int, optional): The length of each random walk.
            (default: :obj:`2`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
    """

    def __init__(self, data, batch_size, sample_type='random_walk',
                 walk_length=2, sample_coverage=50, save_dir=None, log=True):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data
        assert sample_type in ('node', 'random_walk')

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                                value=data.edge_attr, sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.sample_type = sample_type
        self.batch_size = batch_size
        self.walk_length = walk_length
        self.sample_coverage = sample_coverage
        self.log = log

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:  # pragma: no cover
                torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self):
        """Sampling initial nodes by iterating over the random permutation of nodes"""
        tmp_map = torch.arange(self.N, dtype=torch.long)
        all_n_id = torch.randperm(self.N, dtype=torch.long)
        node_samples = []
        for s_id in range(0, self.N, self.batch_size):
            init_n_id = all_n_id[s_id:s_id+self.batch_size]  # [batch_size]

            if self.sample_type == 'random_walk':
                n_id = self.adj.random_walk(init_n_id, self.walk_length)  # [batch_size, walk_length+1]
                n_id = n_id.flatten().unique()  # [num_nodes_in_subgraph]
                tmp_map[n_id] = torch.arange(n_id.size(0), dtype=torch.long)
                res_n_id = tmp_map[init_n_id]
            elif self.sample_type == 'node':
                n_id = init_n_id
                res_n_id = torch.arange(n_id.size(0), dtype=torch.long)
            else:
                raise ValueError(
                    "Unsupported sample type '{}'".format(self.sample_type))

            node_samples.append((n_id, res_n_id))

        return node_samples

    def __sample__(self, num_epochs):
        samples = []
        for _ in range(num_epochs):
            node_samples = self.__sample_nodes__()
            for n_id, res_n_id in node_samples:
                adj, e_id = self.adj.saint_subgraph(n_id)
                samples.append((n_id, e_id, adj, res_n_id))

        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        if self.log:
            pbar = tqdm(total=self.sample_coverage)
            pbar.set_description('GraphSAINT Normalization')

        num_samples = len(self) * self.sample_coverage
        for _ in range(self.sample_coverage):
            samples = self.__sample__(1)

            for n_id, e_id, _, _ in samples:
                node_count[n_id] += 1
                edge_count[e_id] += 1

            if self.log:
                pbar.update(1)

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / (node_count * self.N)

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        n_id, e_id, adj, res_n_id = sample
        data = self.data.__class__()
        data.num_nodes = n_id.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            if item.size(0) == self.N:
                data[key] = item[n_id]
            elif item.size(0) == self.E:
                data[key] = item[e_id]
            else:
                data[key] = item

        data.node_norm = self.node_norm[n_id]
        data.edge_norm = self.edge_norm[e_id]

        data.n_id = n_id
        data.res_n_id = res_n_id

        return data

    def __len__(self):
        return (self.N + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for sample in self.__sample__(1):
            data = self.__get_data_from_sample__(sample)
            yield data
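A hedged usage sketch: `res_n_id` locates the walk seeds inside each subgraph, which is what enables subgraph-based inference; `model` is illustrative, not part of the class:

sampler = MySAINTSampler(data, batch_size=1000, sample_type='random_walk')
for sub in sampler:
    out = model(sub.x, sub.edge_index)
    seed_out = out[sub.res_n_id]  # predictions for the seed nodes only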
Example 11
class GraphSAINTSampler(torch.utils.data.DataLoader):
 
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  
                        pbar.update(node_idx.size(0))
            num_samples += self.num_steps  # one loader pass yields num_steps samples

        if self.log:  
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
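A minimal concrete sampler on top of the base class, as a simplified sketch in the spirit of torch_geometric's GraphSAINTNodeSampler (PyG's own implementation instead samples nodes through random edges):

import torch

class SimpleNodeSampler(GraphSAINTSampler):
    def __sample_nodes__(self, batch_size):
        # sample `batch_size` nodes uniformly at random, with replacement
        return torch.randint(0, self.N, (batch_size, ), dtype=torch.long)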