class NodeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of nodes, generating
    the list of blocks as computation dependency of the said minibatch.

    Keyword arguments are split by name: those accepted by
    :class:`NodeCollator` go to the collator, everything else is forwarded to
    the underlying :class:`torch.utils.data.DataLoader` (or
    :class:`DistDataLoader` when ``g`` is a :class:`DistGraph`).

    Parameters
    ----------
    g : DGLGraph
        The graph.
    nids : Tensor or dict[ntype, Tensor]
        The node set to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a 3-layer GNN for node classification on a set of nodes
    ``train_nid`` on a homogeneous graph where each node takes messages from
    all neighbors (assume the backend is PyTorch):

    >>> sampler = dgl.dataloading.NeighborSampler([None, None, None])
    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, output_nodes, blocks in dataloader:
    ...     train_on(input_nodes, output_nodes, blocks)
    """
    # Names of NodeCollator's __init__ parameters, used to route **kwargs.
    collator_arglist = inspect.getfullargspec(NodeCollator).args

    def __init__(self, g, nids, block_sampler, **kwargs):
        # Route each keyword argument to the collator or the dataloader
        # depending on whether NodeCollator's signature accepts it.
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v
        self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)

        if isinstance(g, DistGraph):
            # Distributed graphs use DistDataLoader; strip kwargs that only
            # make sense for the single-machine torch DataLoader.
            _remove_kwargs_dist(dataloader_kwargs)
            self.dataloader = DistDataLoader(self.collator.dataset,
                                             collate_fn=self.collator.collate,
                                             **dataloader_kwargs)
        else:
            self.dataloader = DataLoader(self.collator.dataset,
                                         collate_fn=self.collator.collate,
                                         **dataloader_kwargs)

    def __next__(self):
        """Advance the wrapped dataloader by one minibatch."""
        # BUG FIX: was ``self.dataloader.__next()`` — no such method exists,
        # so calling next() on this object raised AttributeError. The iterator
        # protocol method is ``__next__``.
        return self.dataloader.__next__()

    def __iter__(self):
        """Return the iterator of the wrapped dataloader."""
        return self.dataloader.__iter__()