Ejemplo n.º 1
0
class RandomNodeSampler(torch.utils.data.DataLoader):
    r"""Data loader that yields the subgraph induced by a randomly drawn set
    of nodes.

    The loader uses itself as the dataset: :meth:`__getitem__` hands the
    sampled index object straight through and :meth:`__collate__` turns it
    into a new data object.

    .. note::
        For an example of using :obj:`RandomNodeSampler`, see
        `examples/ogbn_proteins_deepgcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_proteins_deepgcn.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object; its
            :obj:`edge_index` must not be :obj:`None`.
        num_parts (int): The number of partitions (forwarded to
            :class:`RandomIndexSampler`).
        shuffle (bool, optional): If set to :obj:`True`, the data is reshuffled
            at every epoch (default: :obj:`False`).
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.
    """
    def __init__(self, data, num_parts: int, shuffle: bool = False, **kwargs):
        assert data.edge_index is not None

        num_nodes = data.num_nodes
        self.N = num_nodes
        self.E = data.num_edges

        # Every adjacency entry stores the column position of its edge in
        # ``edge_index`` so that edge-level attributes can be gathered later.
        edge_positions = torch.arange(self.E, device=data.edge_index.device)
        self.adj = SparseTensor(
            row=data.edge_index[0],
            col=data.edge_index[1],
            value=edge_positions,
            sparse_sizes=(num_nodes, num_nodes),
        )

        # Shallow copy: the caller's object keeps its own ``edge_index``.
        stripped = copy.copy(data)
        stripped.edge_index = None
        self.data = stripped

        index_sampler = RandomIndexSampler(self.N, num_parts, shuffle)
        super(RandomNodeSampler, self).__init__(
            self,
            batch_size=1,
            sampler=index_sampler,
            collate_fn=self.__collate__,
            **kwargs,
        )

    def __getitem__(self, idx):
        # The "dataset item" is the sampled index object itself.
        return idx

    def __collate__(self, node_idx):
        """Build the data object for the single node set in ``node_idx``."""
        nodes = node_idx[0]  # batch_size is fixed to 1

        batch = self.data.__class__()
        batch.num_nodes = nodes.size(0)

        sub_adj, _ = self.adj.saint_subgraph(nodes)
        src, dst, edge_positions = sub_adj.coo()
        batch.edge_index = torch.stack([src, dst], dim=0)

        for name, value in self.data:
            chosen = value
            if isinstance(value, Tensor):
                leading = value.size(0)
                if leading == self.N:
                    # node-level attribute
                    chosen = value[nodes]
                elif leading == self.E:
                    # edge-level attribute
                    chosen = value[edge_positions]
            batch[name] = chosen

        return batch
Ejemplo n.º 2
0
def _filter(data, node_idx):
    """Return the subgraph of ``data`` induced by the nodes in ``node_idx``.

    Tensor attributes whose first dimension equals the number of nodes are
    indexed by ``node_idx``; those whose first dimension equals the number
    of edges are indexed by the edge index returned from
    ``saint_subgraph``. Everything else (non-tensor values, 0-dim tensors,
    tensors of any other length) is copied over unchanged.
    ``edge_index``, ``edge_attr`` and ``num_nodes`` are rebuilt from the
    extracted sub-adjacency.

    NOTE(review): the original comment suggested the node indices are
    expected to be sorted -- confirm against the callers.

    Args:
        data: Graph data object providing ``num_nodes``, ``edge_index`` and
            ``edge_attr``, and iterable as ``(key, item)`` pairs.
        node_idx: Indices of the nodes to keep.

    Returns:
        A new ``Data`` object describing the induced subgraph.
    """
    new_data = Data()
    N = data.num_nodes
    E = data.edge_index.size(1)
    adj = SparseTensor(row=data.edge_index[0],
                       col=data.edge_index[1],
                       value=data.edge_attr,
                       sparse_sizes=(N, N))
    new_adj, edge_idx = adj.saint_subgraph(node_idx)
    row, col, value = new_adj.coo()

    for key, item in data:
        if key in ('edge_index', 'edge_attr'):
            # Both are rebuilt from ``new_adj`` below; slicing them here
            # would only be thrown away.
            continue
        # Only tensors with a leading dimension can be sliced; calling
        # ``.size(0)`` on anything else would raise.
        sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
        if sliceable and item.size(0) == N:
            new_data[key] = item[node_idx]
        elif sliceable and item.size(0) == E:
            new_data[key] = item[edge_idx]
        else:
            new_data[key] = item

    new_data.edge_index = torch.stack([row, col], dim=0)
    new_data.num_nodes = len(node_idx)
    new_data.edge_attr = value
    return new_data
Ejemplo n.º 3
0
class RandomNodeSampler(torch.utils.data.DataLoader):
    """Data loader that yields the subgraph induced by a sampled node set.

    The loader acts as its own dataset: ``__getitem__`` returns the sampled
    index object unchanged and ``__collate__`` builds the data object.
    """
    def __init__(self, data, num_parts: int, shuffle: bool = False, **kwargs):
        """
        Randomly sample nodes from a graph

        :param data: graph data object; must provide ``edge_index``,
            ``num_nodes`` and ``num_edges``
        :param num_parts: forwarded to ``RandomIndexSampler`` (presumably
            the number of partitions, i.e. batches per epoch -- confirm)
        :param shuffle: forwarded to ``RandomIndexSampler``
        :param kwargs: extra keyword arguments for
            ``torch.utils.data.DataLoader``
        """
        assert data.edge_index is not None

        self.N = N = data.num_nodes
        self.E = data.num_edges

        # Each adjacency entry stores the column position of its edge in
        # ``edge_index`` so edge-level attributes can be gathered later.
        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1],
                                value=torch.arange(
                                    self.E, device=data.edge_index.device),
                                sparse_sizes=(N, N))

        # Shallow copy so the caller's object keeps its ``edge_index``.
        self.data = copy.copy(data)
        self.data.edge_index = None

        super(RandomNodeSampler, self).__init__(self,
                                                batch_size=1,
                                                sampler=RandomIndexSampler(
                                                    self.N, num_parts,
                                                    shuffle),
                                                collate_fn=self.__collate__,
                                                **kwargs)

    def __getitem__(self, idx):
        # The "dataset item" is the sampled index object itself.
        return idx

    def __collate__(self, node_idx):
        """Build the data object for the single node set in ``node_idx``.

        Tensor attributes with a leading dimension of ``N`` are indexed per
        node, those with a leading dimension of ``E`` per edge; all other
        attributes (including non-tensors and 0-dim tensors) are copied
        unchanged.
        """
        node_idx = node_idx[0]  # batch_size is fixed to 1

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        adj, _ = self.adj.saint_subgraph(node_idx)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            # ``.size(0)`` only exists on tensors and raises on 0-dim ones.
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        return data
Ejemplo n.º 4
0
class GraphSAINTSampler(object):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    Subclasses must implement :meth:`__sample_nodes__`.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch to load.
        num_steps (int, optional): The number of iterations.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        num_workers (int, optional): How many subprocesses to use for data
            sampling.
            :obj:`0` means that the data will be sampled in the main process.
            (default: :obj:`0`)
    """

    def __init__(self, data, batch_size, num_steps=1, sample_coverage=50,
                 save_dir=None, num_workers=0):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                                value=data.edge_attr, sparse_sizes=(N, N))

        # Shallow copy: connectivity and edge attributes now live in
        # ``self.adj`` and are re-attached per sampled subgraph.
        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.batch_size = batch_size
        self.num_steps = num_steps
        self.sample_coverage = sample_coverage
        self.num_workers = num_workers
        self.__count__ = 0

        # Sample workers are started first because the normalization pass
        # below reads from ``__sample_queue__`` when ``num_workers > 0``.
        if self.num_workers > 0:
            self.__sample_queue__ = Queue()
            self.__sample_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_sample__,
                                 args=(self.__sample_queue__, ))
                worker.daemon = True
                worker.start()
                self.__sample_workers__.append(worker)

        # Re-use cached normalization statistics when available.
        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:
                torch.save((self.node_norm, self.edge_norm), path)

        # Data workers are started only now, after ``node_norm`` and
        # ``edge_norm`` have been set, since ``__put_data__`` reads them.
        if self.num_workers > 0:
            self.__data_queue__ = Queue()
            self.__data_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_data__,
                                 args=(self.__data_queue__, ))
                worker.daemon = True
                worker.start()
                self.__data_workers__.append(worker)

    @property
    def __filename__(self):
        """File name used for the cached normalization statistics."""
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self, num_examples):
        """Return ``num_examples`` node-index samples; subclass hook."""
        raise NotImplementedError

    def __sample__(self, num_examples):
        """Draw samples and return ``(node_idx, edge_idx, adj)`` tuples."""
        node_samples = self.__sample_nodes__(num_examples)

        samples = []
        for node_idx in node_samples:
            node_idx = node_idx.unique()
            adj, edge_idx = self.adj.saint_subgraph(node_idx)
            samples.append((node_idx, edge_idx, adj))
        return samples

    def __compute_norm__(self):
        """Estimate ``(node_norm, edge_norm)`` from repeated sampling.

        Samples are drawn in rounds of 200 until the total number of
        sampled nodes reaches ``N * sample_coverage``.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        num_samples = total_sampled_nodes = 0
        pbar = tqdm(total=self.N * self.sample_coverage)
        pbar.set_description('Compute GraphSAINT normalization')
        while total_sampled_nodes < self.N * self.sample_coverage:
            num_sampled_nodes = 0
            if self.num_workers > 0:
                for _ in range(200):
                    node_idx, edge_idx, _ = self.__sample_queue__.get()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            else:
                samples = self.__sample__(200)
                for node_idx, edge_idx, _ in samples:
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            pbar.update(num_sampled_nodes)
            total_sampled_nodes += num_sampled_nodes
            num_samples += 200
        pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        # 0 / 0 (edge and its node never sampled) yields NaN.
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        """Turn one ``(node_idx, edge_idx, adj)`` sample into a data object.

        Tensor attributes with a leading dimension of ``N`` are indexed per
        node, those with a leading dimension of ``E`` per edge; all other
        attributes (including non-tensors and 0-dim tensors) are copied
        unchanged.
        """
        node_idx, edge_idx, adj = sample

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            # ``.size(0)`` only exists on tensors and raises on 0-dim ones.
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        data.node_norm = self.node_norm[node_idx]
        data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __put_sample__(self, queue):
        """Worker loop: keep pushing raw samples into ``queue`` forever."""
        while True:
            sample = self.__sample__(1)[0]
            queue.put(sample)

    def __put_data__(self, queue):
        """Worker loop: convert queued samples into data objects forever."""
        while True:
            sample = self.__sample_queue__.get()
            data = self.__get_data_from_sample__(sample)
            queue.put(data)

    def __next__(self):
        if self.__count__ < len(self):
            self.__count__ += 1
            if self.num_workers > 0:
                data = self.__data_queue__.get()
            else:
                sample = self.__sample__(1)[0]
                data = self.__get_data_from_sample__(sample)
            return data
        else:
            raise StopIteration

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        # Restart the step counter so the sampler can be iterated again.
        self.__count__ = 0
        return self
Ejemplo n.º 5
0
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""Base class of the GraphSAINT samplers introduced in the
    `"GraphSAINT: Graph Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Nodes of the graph held in :obj:`data` are sampled and the induced
    subgraphs are served as mini-batches. When :obj:`sample_coverage` is
    positive, every mini-batch additionally carries the normalization
    coefficients :obj:`node_norm` and :obj:`edge_norm`.

    Concrete samplers implement :meth:`__sample_nodes__`.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or
            :obj:`num_workers`.
    """
    def __init__(self,
                 data,
                 batch_size: int,
                 num_steps: int = 1,
                 sample_coverage: int = 0,
                 save_dir: Optional[str] = None,
                 log: bool = True,
                 **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        num_nodes = data.num_nodes
        self.N = num_nodes
        self.E = data.num_edges

        # Adjacency values are edge positions in ``edge_index``; they are
        # used later to gather edge-level attributes and statistics.
        edge_positions = torch.arange(self.E, device=data.edge_index.device)
        self.adj = SparseTensor(
            row=data.edge_index[0],
            col=data.edge_index[1],
            value=edge_positions,
            sparse_sizes=(num_nodes, num_nodes),
        )

        stripped = copy.copy(data)
        stripped.edge_index = None
        self.data = stripped

        # The loader is its own dataset; one "item" is one sampled subgraph.
        super(GraphSAINTSampler, self).__init__(
            self,
            batch_size=1,
            collate_fn=self.__collate__,
            **kwargs,
        )

        if not self.sample_coverage > 0:
            return

        cache_file = osp.join(save_dir or '', self.__filename__)
        caching = save_dir is not None
        if caching and osp.exists(cache_file):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(cache_file)
            return

        self.node_norm, self.edge_norm = self.__compute_norm__()
        if caching:  # pragma: no cover
            torch.save((self.node_norm, self.edge_norm), cache_file)

    @property
    def __filename__(self):
        """Name of the file caching the normalization statistics."""
        prefix = self.__class__.__name__.lower()
        return f'{prefix}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        """Return sampled node indices; to be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a fresh sample.
        drawn = self.__sample_nodes__(self.__batch_size__)
        node_idx = drawn.unique()
        sub_adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, sub_adj

    def __collate__(self, data_list):
        """Convert the single ``(node_idx, adj)`` item into a data object."""
        assert len(data_list) == 1
        nodes, sub_adj = data_list[0]

        batch = self.data.__class__()
        batch.num_nodes = nodes.size(0)
        src, dst, edge_positions = sub_adj.coo()
        batch.edge_index = torch.stack([src, dst], dim=0)

        for name, value in self.data:
            chosen = value
            if isinstance(value, torch.Tensor):
                leading = value.size(0)
                if leading == self.N:
                    # node-level attribute
                    chosen = value[nodes]
                elif leading == self.E:
                    # edge-level attribute
                    chosen = value[edge_positions]
            batch[name] = chosen

        if self.sample_coverage > 0:
            batch.node_norm = self.node_norm[nodes]
            batch.edge_norm = self.edge_norm[edge_positions]

        return batch

    def __compute_norm__(self):
        """Estimate ``(node_norm, edge_norm)`` by repeatedly sampling.

        Full passes over the sampler are repeated until at least
        ``N * sample_coverage`` nodes have been drawn in total.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        # Raw ``(node_idx, adj)`` items are wanted here, hence the identity
        # collate function instead of ``__collate__``.
        raw_loader = torch.utils.data.DataLoader(
            self,
            batch_size=200,
            collate_fn=lambda items: items,
            num_workers=self.num_workers,
        )

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = 0
        total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for chunk in raw_loader:
                for nodes, sub_adj in chunk:
                    edge_positions = sub_adj.storage.value()
                    node_count[nodes] += 1
                    edge_count[edge_positions] += 1
                    drawn = nodes.size(0)
                    total_sampled_nodes += drawn

                    if self.log:  # pragma: no cover
                        pbar.update(drawn)
            # One pass over the loader yields ``len(self)`` samples.
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        src, _, edge_positions = self.adj.coo()
        # Place the source-node count of every edge at that edge's position.
        src_count = torch.empty_like(edge_count)
        src_count.scatter_(0, edge_positions, node_count[src])
        edge_norm = src_count / edge_count
        edge_norm.clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
Ejemplo n.º 6
0
class Pruner(object):
    """Iteratively prunes the edges whose two endpoints are both in the
    training node set.

    At every :meth:`prune` call a fraction ``ratio`` of the current training
    edges is kept: either the ones with the smallest absolute difference
    between the per-node losses of their endpoints, or a random subset.
    All remaining ("rest") edges are always kept.

    Args:
        edge_index: ``[2, E]`` tensor of edges (row 0 and row 1 are used as
            the two endpoints).
        split_idx: Mapping with a ``'train'`` entry (and a ``'valid'`` entry
            when ``prune_set != 'train'``) holding node indices.
        prune_set: ``'train'`` to prune among training nodes only; any other
            value uses the concatenation of training and validation nodes.
        ratio: Fraction of training edges kept per :meth:`prune` call.
    """

    def __init__(self, edge_index, split_idx, prune_set='train', ratio=0.9):
        self.N = N = int(edge_index.max() + 1)
        self.E = edge_index.size(1)
        self.edge_index = edge_index

        # Adjacency values are the column positions of the edges in
        # ``edge_index``.
        self.adj = SparseTensor(row=edge_index[0],
                                col=edge_index[1],
                                value=torch.arange(edge_index.size(1)),
                                sparse_sizes=(N, N),
                                is_sorted=False)
        if prune_set == 'train':
            self.train_idx = split_idx['train']
        else:
            self.train_idx = torch.cat(
                [split_idx['train'], split_idx['valid']])
        subadj, _ = self.adj.saint_subgraph(self.train_idx)
        _, _, e_idx = subadj.coo()
        # ``reshape(-1)`` instead of ``squeeze()``: squeezing a single-edge
        # result would produce a 0-dim tensor and drop the edge dimension
        # of ``train_edge_index``.
        self.train_e_idx = e_idx.reshape(-1).long()
        self.train_edge_index = self.edge_index[:, self.train_e_idx]
        # Every edge that is not a training edge, in ascending order.
        rest_mask = torch.ones(self.E, dtype=torch.bool)
        rest_mask[self.train_e_idx.cpu()] = False
        self.rest_e_idx = torch.nonzero(rest_mask, as_tuple=False).reshape(-1)
        self.ratio = ratio
        self.loss = torch.zeros(N)
        self.times = 0

    def prune(self, naive=False):
        """Drop ``1 - ratio`` of the current training edges.

        The state before pruning (edges, both candidate masks and the loss
        vector) is saved under ``./savept``; the directory is created when
        missing.

        Args:
            naive: If true, keep a random subset of the training edges;
                otherwise keep the edges with the smallest endpoint-loss
                difference.
        """
        import os

        i = self.times
        diff_loss = torch.abs(self.loss[self.train_edge_index[0]] -
                              self.loss[self.train_edge_index[1]])
        # Indices of the edges with the smallest loss difference.
        _, mask1 = torch.topk(diff_loss,
                              int(len(diff_loss) * self.ratio),
                              largest=False)

        newE = self.train_edge_index.size(1)
        # Random subset of the same relative size.
        mask2 = torch.randperm(newE)[:int(newE * self.ratio)]

        # ``torch.save`` does not create missing directories.
        os.makedirs('./savept', exist_ok=True)
        torch.save(self.train_edge_index,
                   f'./savept/pre_prune_edges_{self.ratio ** i:.4f}.pt')

        torch.save(
            mask1, f'./savept/smart_mask_prune_edges_{self.ratio ** i:.4f}.pt')
        torch.save(
            mask2, f'./savept/naive_mask_prune_edges_{self.ratio ** i:.4f}.pt')
        torch.save(self.loss, f'./savept/loss_{self.ratio ** i:.4f}.pt')
        self.times += 1
        mask = mask2 if naive else mask1
        self.train_e_idx = self.train_e_idx[mask]
        self.train_edge_index = self.train_edge_index[:, mask]
        # Rebuild ``edge_index`` as [kept training edges | rest edges] and
        # renumber both index sets to match the new column layout.
        self.edge_index = self.edge_index[:,
                                          torch.cat([
                                              self.train_e_idx, self.rest_e_idx
                                          ])]
        self.train_e_idx = torch.arange(self.train_e_idx.size(0))
        self.rest_e_idx = torch.arange(
            self.train_e_idx.size(0),
            self.train_e_idx.size(0) + self.rest_e_idx.size(0))
        self.E = self.edge_index.size(1)
        print(
            f'****************trainE : {self.train_e_idx.size(0)} ,restE:{self.rest_e_idx.size(0)}, totalE:{self.E}'
        )

    def update_loss(self, loss):
        """Replace the per-node loss vector used by :meth:`prune`."""
        self.loss = loss
Ejemplo n.º 7
0
class MySAINTSampler(object):
    r"""A new random-walk sampler for GraphSAINT that samples initial nodes
    by iterating over node permutations. The benefit is that we can leverage
    this sampler for subgraph-based inference.

    One pass over the sampler yields ``ceil(N / batch_size)`` subgraphs whose
    initial nodes together cover every node exactly once. Each yielded data
    object carries :obj:`n_id` (original node indices of the subgraph) and
    :obj:`res_n_id` (positions of the initial nodes inside the subgraph).

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The number of walks to sample per batch.
        sample_type (str, optional): :obj:`'random_walk'` to grow every
            batch by random walks from its initial nodes, or :obj:`'node'`
            to use the initial nodes only. (default: :obj:`'random_walk'`)
        walk_length (int): The length of each random walk.
            (default: :obj:`2`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
    """

    def __init__(self, data, batch_size, sample_type='random_walk', walk_length=2, sample_coverage=50,
                 save_dir=None, log=True):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data
        assert sample_type in ('node', 'random_walk')

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                                value=data.edge_attr, sparse_sizes=(N, N))

        # Shallow copy: connectivity and edge attributes now live in
        # ``self.adj`` and are re-attached per sampled subgraph.
        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.sample_type = sample_type
        self.batch_size = batch_size
        self.walk_length = walk_length
        self.sample_coverage = sample_coverage
        self.log = log

        # Re-use cached normalization statistics when available.
        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:  # pragma: no cover
                torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        """File name used for the cached normalization statistics."""
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self):
        """Sampling initial nodes by iterating over the random permutation of nodes.

        Returns a list with one ``(n_id, res_n_id)`` pair per batch.
        """
        # Scratch lookup from original node index to subgraph position; only
        # the entries written for the current batch are read back.
        tmp_map = torch.arange(self.N, dtype=torch.long)
        all_n_id = torch.randperm(self.N, dtype=torch.long)
        node_samples = []
        for s_id in range(0, self.N, self.batch_size):
            init_n_id = all_n_id[s_id:s_id+self.batch_size]  # [batch_size]

            if self.sample_type == 'random_walk':
                n_id = self.adj.random_walk(init_n_id, self.walk_length)  # [batch_size, walk_length+1]
                n_id = n_id.flatten().unique()  # [num_nodes_in_subgraph]
                tmp_map[n_id] = torch.arange(n_id.size(0), dtype=torch.long)
                res_n_id = tmp_map[init_n_id]
            elif self.sample_type == 'node':
                n_id = init_n_id
                res_n_id = torch.arange(n_id.size(0), dtype=torch.long)
            else:
                raise ValueError('Unsupported value type {}'.format(self.sample_type))

            node_samples.append((n_id, res_n_id))

        return node_samples

    def __sample__(self, num_epoches):
        """Run ``num_epoches`` full passes and return all
        ``(n_id, e_id, adj, res_n_id)`` samples."""
        samples = []
        for _ in range(num_epoches):
            node_samples = self.__sample_nodes__()
            for n_id, res_n_id in node_samples:
                adj, e_id = self.adj.saint_subgraph(n_id)
                samples.append((n_id, e_id, adj, res_n_id))

        return samples

    def __compute_norm__(self):
        """Estimate ``(node_norm, edge_norm)`` from ``sample_coverage``
        full passes over the graph."""
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        pbar = None
        if self.log:
            pbar = tqdm(total=self.sample_coverage)
            pbar.set_description('GraphSAINT Normalization')

        num_samples = len(self) * self.sample_coverage
        for _ in range(self.sample_coverage):
            samples = self.__sample__(1)

            for n_id, e_id, _, _ in samples:
                node_count[n_id] += 1
                edge_count[e_id] += 1

            if pbar is not None:
                pbar.update(1)

        # Release the progress bar once all passes are done.
        if pbar is not None:
            pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        # 0 / 0 (edge and its node never sampled) yields NaN.
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / (node_count * self.N)

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        """Turn one ``(n_id, e_id, adj, res_n_id)`` sample into a data object.

        Tensor attributes with a leading dimension of ``N`` are indexed per
        node, those with a leading dimension of ``E`` per edge; all other
        attributes (including non-tensors and 0-dim tensors) are copied
        unchanged.
        """
        n_id, e_id, adj, res_n_id = sample
        data = self.data.__class__()
        data.num_nodes = n_id.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            # ``.size(0)`` only exists on tensors and raises on 0-dim ones.
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[n_id]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[e_id]
            else:
                data[key] = item

        data.node_norm = self.node_norm[n_id]
        data.edge_norm = self.edge_norm[e_id]

        data.n_id = n_id
        data.res_n_id = res_n_id

        return data

    def __len__(self):
        # Number of batches per pass: ceil(N / batch_size).
        return (self.N + self.batch_size-1) // self.batch_size

    def __iter__(self):
        for sample in self.__sample__(1):
            data = self.__get_data_from_sample__(sample)
            yield data
Ejemplo n.º 8
0
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""Base class for GraphSAINT-style subgraph samplers.

    The loader acts as its own dataset with ``num_steps`` items: every item
    access draws a node sample via :meth:`__sample_nodes__` (to be
    implemented by subclasses) and :meth:`__collate__` turns it into a data
    object of the induced subgraph. When ``sample_coverage > 0``,
    normalization coefficients are estimated once at construction time and
    attached to every batch as ``node_norm`` and ``edge_norm``.

    Args:
        data: The graph data object; must provide ``edge_index`` and must
            not already contain ``node_norm`` or ``edge_norm``.
        batch_size (int): Passed to :meth:`__sample_nodes__` on every draw.
        num_steps (int, optional): Number of batches per epoch.
            (default: :obj:`1`)
        sample_coverage (int, optional): Normalization sampling continues
            until ``N * sample_coverage`` nodes have been drawn; ``0``
            disables normalization. (default: :obj:`0`)
        save_dir (str, optional): If set, normalization statistics are
            loaded from / saved to this directory. (default: :obj:`None`)
        log (bool, optional): If :obj:`False`, no progress bar is shown
            during normalization. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as
            :obj:`num_workers`.
    """

    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        # Adjacency values are edge positions in ``edge_index``; they are
        # used later to gather edge-level attributes and statistics.
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        # Shallow copy so the caller's object keeps its ``edge_index``.
        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        """File name used for the cached normalization statistics."""
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        """Return sampled node indices; to be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a fresh sample.
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        """Convert the single ``(node_idx, adj)`` item into a data object."""
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        """Estimate ``(node_norm, edge_norm)`` by repeatedly sampling.

        Full passes over the sampler are repeated until at least
        ``N * sample_coverage`` nodes have been drawn in total.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        # Raw ``(node_idx, adj)`` items are wanted here, hence the identity
        # collate function instead of ``__collate__``.
        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:
                        pbar.update(node_idx.size(0))
            # One pass over ``loader`` draws ``len(self) == num_steps``
            # samples; 200 is only the chunk size of the loader, so counting
            # 200 per pass would mis-scale ``node_norm``.
            num_samples += self.num_steps

        if self.log:
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        # Place the source-node count of every edge at that edge's position.
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        # 0 / 0 (edge and its source never sampled) yields NaN.
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
Ejemplo n.º 9
0
class RandomNodeSampler(torch.utils.data.DataLoader):
    r"""A data loader that randomly samples nodes within a graph and returns
    their induced subgraph. Edges of the underlying graph can additionally be
    pruned between epochs via :meth:`prune`.

    .. note::
        For an example of using :obj:`RandomNodeSampler`, see
        `examples/ogbn_proteins_deepgcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_proteins_deepgcn.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        shuffle (bool, optional): If set to :obj:`True`, the data is reshuffled
            at every epoch (default: :obj:`False`).
        split_idx (dict, optional): Node index tensors stored under the keys
            :obj:`'train'` and :obj:`'valid'`. Required if :obj:`prune` is
            set. (default: :obj:`None`)
        num_edges (int, optional): The number of edge ids to use. If
            :obj:`None`, :obj:`data.num_edges` is used.
            (default: :obj:`None`)
        prune (bool, optional): If set to :obj:`True`, prepares the edge
            bookkeeping that :meth:`prune` needs for pruning edges between
            the nodes selected by :obj:`prune_set`. (default: :obj:`False`)
        prune_set (str, optional): :obj:`'train'` restricts pruning to edges
            between training nodes; any other value uses training and
            validation nodes. (default: :obj:`'train'`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.
    """
    _TRAIN_PRUNE_METHODS = ('ada', 'random', 'naive')
    _GLOBAL_PRUNE_METHODS = ('random', 'naive')

    def __init__(self, data, num_parts: int, shuffle: bool = False,
                 split_idx=None, num_edges=None, prune=False,
                 prune_set='train', **kwargs):
        assert data.edge_index is not None
        if prune and split_idx is None:
            raise ValueError("`split_idx` is required when `prune` is set")

        self.N = data.num_nodes
        self.E = data.num_edges if num_edges is None else num_edges
        self.split_idx = split_idx
        self.adj = self._build_adj(data.edge_index)

        # The copy keeps `edge_index`; `__collate__` skips that key and
        # builds the subgraph connectivity from `self.adj` instead.
        self.data = copy.copy(data)

        super(RandomNodeSampler, self).__init__(
            self, batch_size=1,
            sampler=RandomIndexSampler(self.N, num_parts, shuffle),
            collate_fn=self.__collate__, **kwargs)

        if prune:
            self.edge_index = data.edge_index
            if prune_set == 'train':
                self.train_idx = self.split_idx['train']
            else:
                self.train_idx = torch.cat(
                    [self.split_idx['train'], self.split_idx['valid']])
            # Ids of the edges whose two endpoints both lie in `train_idx`.
            subadj, _ = self.adj.saint_subgraph(self.train_idx)
            _, _, e_idx = subadj.coo()
            self.train_e_idx = e_idx.reshape(-1).long()
            self.train_edge_index = self.edge_index[:, self.train_e_idx]
            self.rest_e_idx = self._complement_edge_ids(self.train_e_idx)

        self.times = 0

    def __getitem__(self, idx):
        return idx

    def _build_adj(self, edge_index):
        r"""Builds the sparse adjacency whose values are the edge ids
        :obj:`0, ..., self.E - 1`."""
        return SparseTensor(
            row=edge_index[0], col=edge_index[1],
            value=torch.arange(self.E, device=edge_index.device),
            sparse_sizes=(self.N, self.N))

    def _complement_edge_ids(self, edge_ids):
        r"""Returns, in ascending order, all edge ids in
        :obj:`[0, self.E)` that are not contained in :obj:`edge_ids`."""
        is_rest = torch.ones(self.E, dtype=torch.bool,
                             device=edge_ids.device)
        is_rest[edge_ids] = False
        return is_rest.nonzero(as_tuple=False).view(-1)

    @staticmethod
    def _degree_cap_mask(edge_index, ratio):
        r"""Returns a boolean keep-mask over the columns of
        :obj:`edge_index`. Every source node with more than
        :obj:`int(max_degree * ratio)` edges keeps only that many of them,
        chosen at random."""
        src = edge_index[0]
        num_edges = src.size(0)
        keep = torch.ones(num_edges, dtype=torch.bool, device=src.device)
        if num_edges == 0:
            return keep

        degrees = scatter(torch.ones(num_edges, device=src.device), src)
        cap = int(degrees.max() * ratio)
        # `view(-1)` keeps the tensors one-dimensional even if only a single
        # node or a single edge matches.
        for node in (degrees > cap).nonzero(as_tuple=False).view(-1):
            eids = (src == node).nonzero(as_tuple=False).view(-1)
            shuffled = eids[torch.randperm(eids.size(0), device=eids.device)]
            keep[shuffled[cap:]] = False
        return keep

    def prune(self, loss, ratio=0, method='ada', globe=False, savept=False):
        r"""Prunes edges of the stored graph and rebuilds :obj:`self.adj`.

        Args:
            loss (Tensor): Values indexed by node id. Only read by the
                :obj:`'ada'` method.
            ratio (float, optional): For :obj:`'ada'` and :obj:`'random'` the
                fraction of candidate edges that is kept. For :obj:`'naive'`
                the per-source-node edge cap is
                :obj:`int(max_degree * ratio)`. (default: :obj:`0`)
            method (str, optional): :obj:`'ada'` keeps the edges with the
                smallest absolute loss difference between their endpoints,
                :obj:`'random'` keeps a random subset and :obj:`'naive'`
                caps the out-degree. :obj:`'ada'` is not available if
                :obj:`globe` is set. (default: :obj:`'ada'`)
            globe (bool, optional): If set to :obj:`True`, all edges are
                pruning candidates; otherwise only the edges selected at
                construction time (:obj:`prune=True`) are.
                (default: :obj:`False`)
            savept (bool, optional): Unused; kept for backward
                compatibility. (default: :obj:`False`)

        Raises:
            ValueError: If :obj:`method` is not supported for the chosen
                mode.
            RuntimeError: If :obj:`globe` is not set and the sampler was
                created without :obj:`prune=True`.
        """
        if globe:
            supported = self._GLOBAL_PRUNE_METHODS
        else:
            supported = self._TRAIN_PRUNE_METHODS
        if method not in supported:
            raise ValueError(
                f"Unsupported pruning method {method!r} "
                f"(globe={bool(globe)}); expected one of {supported}")

        if globe:
            self._prune_all_edges(ratio, method)
        else:
            self._prune_train_edges(loss, ratio, method)

    def _prune_train_edges(self, loss, ratio, method):
        r"""Prunes only the edges tracked in :obj:`self.train_edge_index` and
        re-orders all edges as :obj:`[kept training edges, other edges]`."""
        if not hasattr(self, 'train_edge_index'):
            raise RuntimeError(
                "Pruning training edges requires a sampler created with "
                "`prune=True`")

        num_train = self.train_edge_index.size(1)
        if method == 'ada':
            diff_loss = torch.abs(loss[self.train_edge_index[0]] -
                                  loss[self.train_edge_index[1]])
            print(diff_loss.size(0))
            # `topk` raises if asked for more elements than exist.
            num_keep = min(int(len(diff_loss) * ratio), len(diff_loss))
            _, mask = torch.topk(diff_loss, num_keep, largest=False)
        elif method == 'random':
            mask = torch.randperm(num_train)[:int(num_train * ratio)]
        else:  # 'naive'
            mask = self._degree_cap_mask(self.train_edge_index, ratio)

        self.train_e_idx = self.train_e_idx[mask]
        self.train_edge_index = self.train_edge_index[:, mask]

        # New edge order: kept training edges first, then all other edges.
        order = torch.cat([self.train_e_idx, self.rest_e_idx])
        self.edge_index = self.edge_index[:, order]

        print('len',len(self.train_e_idx),len(self.rest_e_idx), self.train_edge_index.size(),self.edge_index.size())
        edge_attr = getattr(self.data, 'edge_attr', None)
        if edge_attr is not None:
            self.data.edge_attr = edge_attr[order]
        self.data.edge_index = self.data.edge_index[:, order]

        # After the re-ordering the edge ids are simply consecutive.
        num_kept = self.train_e_idx.size(0)
        num_rest = self.rest_e_idx.size(0)
        self.train_e_idx = torch.arange(num_kept)
        self.rest_e_idx = torch.arange(num_kept, num_kept + num_rest)

        self.E = self.edge_index.size(1)
        self.adj = self._build_adj(self.edge_index)

    def _prune_all_edges(self, ratio, method):
        r"""Prunes among all edges stored in :obj:`self.data`."""
        if method == 'random':
            mask = torch.randperm(self.E)[:int(self.E * ratio)]
        else:  # 'naive'
            mask = self._degree_cap_mask(self.data.edge_index, ratio)

        edge_attr = getattr(self.data, 'edge_attr', None)
        self.data.edge_index = self.data.edge_index[:, mask]
        if edge_attr is not None:
            self.data.edge_attr = edge_attr[mask]

        self.E = self.data.edge_index.size(1)
        self.adj = self._build_adj(self.data.edge_index)

    def __collate__(self, node_idx):
        r"""Builds the subgraph induced by the sampled nodes.

        Args:
            node_idx (list): A single-element list holding the tensor of
                sampled node indices.

        Returns:
            A new object of the same class as :obj:`self.data` that holds
            the induced :obj:`edge_index`, the sampled node ids in
            :obj:`n_id` and the sliced attributes of :obj:`self.data`.
        """
        node_idx = node_idx[0]
        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        data.n_id = node_idx

        adj, _ = self.adj.saint_subgraph(node_idx)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            # The connectivity was rebuilt above; never overwrite it with a
            # slice of the full `edge_index` (its first dimension is 2, which
            # would match `self.N` for a two-node graph).
            if key == 'edge_index':
                continue
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                # Non-tensor and graph-level attributes are copied as-is.
                data[key] = item

        return data
Ejemplo n.º 10
0
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Edges of the underlying graph can additionally be pruned between epochs
    via :meth:`prune`.

    .. note::
        In this variant the computation of the :obj:`node_norm` and
        :obj:`edge_norm` normalization coefficients is disabled: mini-batches
        carry neither attribute, and :obj:`sample_coverage` and
        :obj:`save_dir` only affect :obj:`__filename__`.

    .. note::
        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        prune (bool, optional): If set to :obj:`True`, prepares the edge
            bookkeeping that :meth:`prune` needs, using
            :obj:`data.train_mask` (and :obj:`data.valid_mask`).
            (default: :obj:`False`)
        prune_set (str, optional): :obj:`'train'` restricts pruning to edges
            between training nodes; any other value uses training and
            validation nodes. (default: :obj:`'train'`)
        prune_type (str, optional): Currently unused.
            (default: :obj:`'adaptive'`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or
            :obj:`num_workers`.
    """
    _PRUNE_METHODS = ('ada', 'random', 'naive')

    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, prune=False, prune_set='train',
                 prune_type='adaptive', **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = data.num_nodes
        self.E = data.num_edges

        self.adj = self._build_adj(data.edge_index)

        self.data = copy.copy(data)

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if prune:
            if prune_set == 'train':
                node_mask = self.data.train_mask
            else:
                node_mask = self.data.train_mask + self.data.valid_mask
            # `view(-1)` keeps the index one-dimensional even if only a
            # single node is selected.
            self.train_idx = node_mask.nonzero(as_tuple=False).view(-1)
            # Ids of the edges whose two endpoints both lie in `train_idx`.
            subadj, _ = self.adj.saint_subgraph(self.train_idx)
            _, _, e_idx = subadj.coo()
            self.train_e_idx = e_idx.reshape(-1).long()
            self.train_edge_index = self.data.edge_index[:, self.train_e_idx]
            self.rest_e_idx = self._complement_edge_ids(self.train_e_idx)

        # NOTE: The normalization statistics (`node_norm` / `edge_norm`) are
        # intentionally not computed here in this variant.

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def _build_adj(self, edge_index):
        r"""Builds the sparse adjacency whose values are the edge ids
        :obj:`0, ..., self.E - 1`."""
        return SparseTensor(
            row=edge_index[0], col=edge_index[1],
            value=torch.arange(self.E, device=edge_index.device),
            sparse_sizes=(self.N, self.N))

    def _complement_edge_ids(self, edge_ids):
        r"""Returns, in ascending order, all edge ids in
        :obj:`[0, self.E)` that are not contained in :obj:`edge_ids`."""
        is_rest = torch.ones(self.E, dtype=torch.bool,
                             device=edge_ids.device)
        is_rest[edge_ids] = False
        return is_rest.nonzero(as_tuple=False).view(-1)

    @staticmethod
    def _degree_cap_mask(edge_index, ratio):
        r"""Returns a boolean keep-mask over the columns of
        :obj:`edge_index`. Every source node with more than
        :obj:`int(max_degree * ratio)` edges keeps only that many of them,
        chosen at random."""
        src = edge_index[0]
        num_edges = src.size(0)
        keep = torch.ones(num_edges, dtype=torch.bool, device=src.device)
        if num_edges == 0:
            return keep

        degrees = scatter(torch.ones(num_edges, device=src.device), src)
        cap = int(degrees.max() * ratio)
        for node in (degrees > cap).nonzero(as_tuple=False).view(-1):
            eids = (src == node).nonzero(as_tuple=False).view(-1)
            shuffled = eids[torch.randperm(eids.size(0), device=eids.device)]
            keep[shuffled[cap:]] = False
        return keep

    def prune(self, loss, ratio=0, method='ada', glob=False):
        r"""Prunes the edges tracked in :obj:`self.train_edge_index`,
        re-orders all edges as :obj:`[kept training edges, other edges]` and
        rebuilds :obj:`self.adj`.

        Args:
            loss (Tensor): Values indexed by node id. Only read by the
                :obj:`'ada'` method.
            ratio (float, optional): For :obj:`'ada'` and :obj:`'random'` the
                fraction of training edges that is kept. For :obj:`'naive'`
                the per-source-node edge cap is
                :obj:`int(max_degree * ratio)`. (default: :obj:`0`)
            method (str, optional): :obj:`'ada'` keeps the edges with the
                smallest absolute loss difference between their endpoints,
                :obj:`'random'` keeps a random subset and :obj:`'naive'`
                caps the out-degree. (default: :obj:`'ada'`)
            glob (bool, optional): If set to :obj:`True`, the call is a
                no-op. (default: :obj:`False`)

        Raises:
            ValueError: If :obj:`method` is not one of the supported names.
            RuntimeError: If the sampler was created without
                :obj:`prune=True`.
        """
        if glob:
            # Global pruning is not implemented for this sampler.
            return

        if method not in self._PRUNE_METHODS:
            raise ValueError(
                f"Unsupported pruning method {method!r}; "
                f"expected one of {self._PRUNE_METHODS}")
        if not hasattr(self, 'train_edge_index'):
            raise RuntimeError(
                "Pruning training edges requires a sampler created with "
                "`prune=True`")

        num_train = self.train_edge_index.size(1)
        if method == 'ada':
            diff_loss = torch.abs(loss[self.train_edge_index[0]] -
                                  loss[self.train_edge_index[1]])
            # `topk` raises if asked for more elements than exist.
            num_keep = min(int(len(diff_loss) * ratio), len(diff_loss))
            _, mask = torch.topk(diff_loss, num_keep, largest=False)
        elif method == 'random':
            mask = torch.randperm(num_train)[:int(num_train * ratio)]
        else:  # 'naive'
            mask = self._degree_cap_mask(self.train_edge_index, ratio)

        self.train_e_idx = self.train_e_idx[mask]
        self.train_edge_index = self.train_edge_index[:, mask]

        # New edge order: kept training edges first, then all other edges.
        order = torch.cat([self.train_e_idx, self.rest_e_idx])
        edge_attr = getattr(self.data, 'edge_attr', None)
        if edge_attr is not None:
            self.data.edge_attr = edge_attr[order]
        self.data.edge_index = self.data.edge_index[:, order]

        # After the re-ordering the edge ids are simply consecutive.
        num_kept = self.train_e_idx.size(0)
        num_rest = self.rest_e_idx.size(0)
        self.train_e_idx = torch.arange(num_kept)
        self.rest_e_idx = torch.arange(num_kept, num_kept + num_rest)

        self.E = self.data.num_edges
        self.adj = self._build_adj(self.data.edge_index)

    def __collate__(self, data_list):
        r"""Turns the single sampled :obj:`(node_idx, adj)` pair into a
        mini-batch object of the same class as :obj:`self.data`.

        Args:
            data_list (list): A single-element list holding the tuple
                returned by :meth:`__getitem__`.
        """
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        data.node_idx = node_idx
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            # The connectivity was rebuilt above; never overwrite it with a
            # slice of the full `edge_index` (its first dimension is 2, which
            # would match `self.N` for a two-node graph).
            if key == 'edge_index':
                continue
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                # Non-tensor and graph-level attributes are copied as-is.
                data[key] = item

        return data
Ejemplo n.º 11
0
class MyNeighborSampler(NeighborSampler):

    def __init__(self, data, size, num_hops,
                 batch_size=1, shuffle=False, drop_last=False, bipartite=True,
                 add_self_loops=False, flow='source_to_target',
                 use_negative_sampling=False, neg_sample_ratio=None):
        r"""Sets up the sparse adjacency and the negative sampling options,
        then forwards the remaining arguments to the parent sampler.

        Args:
            data: The graph data object; its :obj:`num_nodes`,
                :obj:`num_edges` and :obj:`edge_index` are read here.
            size, num_hops, batch_size, shuffle, drop_last, bipartite,
                add_self_loops, flow: Passed unchanged (positionally) to the
                parent constructor.
            use_negative_sampling (bool, optional): Whether negative edges
                are generated for every produced mini-batch.
                (default: :obj:`False`)
            neg_sample_ratio (float, optional): Must be a positive number if
                :obj:`use_negative_sampling` is set. (default: :obj:`None`)
        """
        num_nodes = data.num_nodes
        num_edges = data.num_edges
        self.N = num_nodes
        self.E = num_edges

        # Adjacency whose stored values are the edge ids 0, ..., E - 1.
        edge_ids = torch.arange(num_edges, device=data.edge_index.device)
        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1],
                                value=edge_ids,
                                sparse_sizes=(num_nodes, num_nodes))

        # A ratio is only validated when negative sampling is requested.
        if use_negative_sampling:
            assert neg_sample_ratio is not None
            assert neg_sample_ratio > 0.0

        self.use_negative_sampling = use_negative_sampling
        # Batches are later grouped into chunks of this many elements.
        self.num_proc = mp.cpu_count()
        self.neg_sample_ratio = neg_sample_ratio

        parent_args = (data, size, num_hops, batch_size, shuffle, drop_last,
                       bipartite, add_self_loops, flow)
        super().__init__(*parent_args)

    def __produce_subgraph__(self, b_id):
        r"""Produces a :obj:`Data` object holding the subgraph data for a given
        mini-batch :obj:`b_id`.

        Args:
            b_id (LongTensor): Original ids of the mini-batch nodes.

        Returns:
            Data: Holds :obj:`edge_index` (relabeled to subgraph-local node
            ids, duplicates removed), :obj:`e_id` (one original edge id per
            remaining edge), :obj:`n_id` (original ids of all subgraph
            nodes), :obj:`b_id`, :obj:`sub_b_id` (local ids of the
            mini-batch nodes), :obj:`num_nodes` and :obj:`full_edge_index`
            (the complete induced connectivity of :obj:`n_id` if negative
            sampling is enabled, else :obj:`None`).
        """

        n_ids = [b_id]
        e_ids = []
        edge_indices = []

        for hop in range(self.num_hops):
            # NOTE(review): presumably samples up to `self.size[hop]` edges
            # per frontier node - confirm against `neighbor_sampler`.
            e_id = neighbor_sampler(n_ids[-1], self.cumdeg, self.size[hop])
            n_id = self.edge_index_j.index_select(0, e_id)
            n_id = n_id.unique(sorted=False)
            n_ids.append(n_id)
            e_ids.append(self.e_assoc.index_select(0, e_id))
            edge_index = self.data.edge_index.index_select(1, e_ids[-1])
            edge_indices.append(edge_index)

        # Relabel the original node ids to positions within `n_id`.
        n_id = torch.unique(torch.cat(n_ids, dim=0), sorted=False)
        self.tmp[n_id] = torch.arange(n_id.size(0))
        e_id = torch.cat(e_ids, dim=0)
        edge_index = self.tmp[torch.cat(edge_indices, dim=1)]

        # Remove duplicated edges by encoding every (row, col) pair as the
        # single integer `row * num_nodes + col`.
        num_nodes = n_id.size(0)
        idx = edge_index[0] * num_nodes + edge_index[1]
        idx, inv = idx.unique(sorted=False, return_inverse=True)
        # Floor division is required to decode the row: `/` performs true
        # division on integer tensors and would yield fractional indices.
        edge_index = torch.stack([idx // num_nodes, idx % num_nodes], dim=0)
        e_id = e_id.new_zeros(edge_index.size(1)).scatter_(0, inv, e_id)

        # n_id: original ID of nodes in the whole sub-graph.
        # b_id: original ID of nodes in the training graph.
        # sub_b_id: sampled ID of nodes in the training graph.

        # Get full-subgraph for negative sampling.
        # Will be deleted at __call__.
        if self.use_negative_sampling:
            adj, _ = self.adj.saint_subgraph(n_id)
            row, col, _ = adj.coo()
            full_edge_index = torch.stack([row, col], dim=0)
        else:
            full_edge_index = None

        return Data(edge_index=edge_index, e_id=e_id, n_id=n_id, b_id=b_id,
                    sub_b_id=self.tmp[b_id], full_edge_index=full_edge_index,
                    num_nodes=num_nodes)

    def __call__(self, subset=None):
        r"""Yields one produced mini-batch object after another for the
        nodes in :obj:`subset`.

        Bipartite data flows are produced if :obj:`self.bipartite` is set,
        subgraphs otherwise. With negative sampling enabled, the work is
        delegated to :meth:`__call_with_negatives__`.

        Args:
            subset (LongTensor or BoolTensor, optional): The initial nodes to
                propagate messages to. If set to :obj:`None`, will iterate over
                all nodes in the graph. (default: :obj:`None`)
        """
        produce = (self.__produce_bipartite_data_flow__
                   if self.bipartite else self.__produce_subgraph__)

        if self.use_negative_sampling:
            yield from self.__call_with_negatives__(subset, produce)
            return

        for batch in self.__get_batches__(subset):
            yield produce(batch)

    def __call_with_negatives__(self, subset, produce):
        r"""Yields produced mini-batches with an additional
        :obj:`neg_edge_index` attribute.

        Batches are processed in groups of :obj:`self.num_proc`: all
        mini-batches of a group are produced first, then their negative
        edges are generated via :obj:`fetch_and_generate`.

        Args:
            subset: Forwarded to :meth:`__get_batches__`.
            produce (callable): Maps a batch of node ids to a data object
                exposing :obj:`full_edge_index`, :obj:`n_id` and
                :obj:`edge_index`.
        """
        batch_groups = grouper(self.__get_batches__(subset), self.num_proc)
        for batch_group in batch_groups:
            subgraphs = [produce(batch) for batch in batch_group]

            # One job per subgraph: (existing edges, number of nodes, number
            # of negative samples to draw).
            jobs = []
            for sub in subgraphs:
                num_negatives = int(
                    self.neg_sample_ratio * sub.edge_index.size(1))
                jobs.append((sub.full_edge_index.numpy(),
                             sub.n_id.size(0),
                             num_negatives))

            negatives = fetch_and_generate(
                iters=jobs,
                func=negative_sampling_numpy,
                num_proc=self.num_proc,
            )

            for sub, neg in zip(subgraphs, negatives):
                # The full connectivity was only needed to draw negatives.
                del sub.full_edge_index
                sub.neg_edge_index = torch.as_tensor(neg).long()
                yield sub