def build_dataloader(self, dataset, rank):
    data = dataset[0]
    train_dataset = NeighborSamplerDataset(dataset, self.sample_size, self.batch_size, data.train_mask)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=self.args.world_size, rank=rank
    )
    settings = dict(
        num_workers=self.num_workers,
        persistent_workers=True,
        pin_memory=True,
        batch_size=self.args.batch_size,
    )
    # Compare versions numerically; a plain string comparison would
    # misorder versions such as "1.10.0" < "1.7.1".
    if [int(x) for x in torch.__version__.split("+")[0].split(".")[:3]] < [1, 7, 1]:
        settings.pop("persistent_workers")

    data.train()
    train_loader = NeighborSampler(dataset=train_dataset, sizes=self.sample_size, sampler=train_sampler, **settings)

    # Evaluation runs without gradients, so a larger batch fits in memory;
    # sizes=[-1] keeps every neighbor instead of sampling.
    settings["batch_size"] *= 5
    data.eval()
    test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
    val_loader = test_loader
    return train_dataset, (train_loader, val_loader, test_loader)
def fit(self, model, dataset):
    self.data = dataset[0]
    self.data.edge_index, _ = add_remaining_self_loops(self.data.edge_index)
    if hasattr(self.data, "edge_index_train"):
        self.data.edge_index_train, _ = add_remaining_self_loops(self.data.edge_index_train)
    self.evaluator = dataset.get_evaluator()
    self.loss_fn = dataset.get_loss_fn()

    self.train_loader = NeighborSampler(
        data=self.data,
        mask=self.data.train_mask,
        sizes=self.sample_size,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
    )
    self.test_loader = NeighborSampler(
        data=self.data, mask=None, sizes=[-1], batch_size=self.batch_size, shuffle=False
    )

    self.model = model.to(self.device)
    self.model.set_data_device(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    best_model = self.train()
    self.model = best_model
    acc, loss = self._test_step()
    return dict(Acc=acc["test"], ValAcc=acc["val"])
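# Self-loops are added up front so that neighbor aggregation includes each
# node's own features; add_remaining_self_loops only inserts the loops that
# are missing and leaves existing ones untouched. A quick illustration,
# assuming the PyG-style signature that the call above matches:
#
#     import torch
#     from torch_geometric.utils import add_remaining_self_loops
#
#     ei = torch.tensor([[0, 1], [1, 0]])   # edges 0->1 and 1->0, no loops
#     ei, _ = add_remaining_self_loops(ei)  # appends 0->0 and 1->1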
def build_dataloader(self, dataset, rank):
    train_dataset = SAINTDataset(dataset, self.sampler_from_args(self.args))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=self.args.world_size, rank=rank
    )
    settings = dict(
        num_workers=self.num_workers,
        persistent_workers=True,
        pin_memory=True,
        batch_size=self.args.batch_size,
    )
    # Compare versions numerically; a plain string comparison would
    # misorder versions such as "1.10.0" < "1.7.1".
    if [int(x) for x in torch.__version__.split("+")[0].split(".")[:3]] < [1, 7, 1]:
        settings.pop("persistent_workers")

    # Each SAINTDataset item is already a sampled subgraph, so the loader uses
    # batch_size=1 with a collate_fn that unwraps it. Reuse `settings` (minus
    # batch_size) so persistent_workers stays consistent with the check above.
    loader_settings = {k: v for k, v in settings.items() if k != "batch_size"}
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=False,
        sampler=train_sampler,
        collate_fn=batcher,
        **loader_settings,
    )
    test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
    val_loader = test_loader
    return train_dataset, (train_loader, val_loader, test_loader)
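# A minimal sketch of what a collate function like `batcher` might look like;
# this is an assumption, not the library's actual implementation. Because each
# SAINTDataset item is a fully sampled subgraph, the DataLoader above runs with
# batch_size=1 and collation reduces to unwrapping the singleton list.
def batcher_sketch(batch):
    assert len(batch) == 1, "one pre-sampled subgraph per step"
    return batch[0]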
def build_dataloader(self, dataset, rank):
    # Rank 0 partitions the graph (METIS) and caches the result; the other
    # ranks wait at the barrier and then load the cached partition.
    if self.device != 0:
        dist.barrier()
    data = dataset[0]
    train_dataset = ClusteredDataset(dataset, self.n_cluster, self.batch_size)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=self.args.world_size, rank=rank
    )
    settings = dict(
        num_workers=self.num_workers,
        persistent_workers=True,
        pin_memory=True,
        batch_size=self.args.batch_size,
    )
    # Compare versions numerically; a plain string comparison would
    # misorder versions such as "1.10.0" < "1.7.1".
    if [int(x) for x in torch.__version__.split("+")[0].split(".")[:3]] < [1, 7, 1]:
        settings.pop("persistent_workers")

    data.train()
    train_loader = ClusteredLoader(
        dataset=train_dataset, n_cluster=self.args.n_cluster, method="metis", sampler=train_sampler, **settings
    )
    if self.device == 0:
        dist.barrier()  # release the ranks waiting above

    settings["batch_size"] *= 5
    data.eval()
    test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
    val_loader = test_loader
    return train_dataset, (train_loader, val_loader, test_loader)
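# The two dist.barrier() calls above form a "build once, share the cache"
# handshake: rank 0 runs the expensive METIS partition while the other ranks
# wait, then everyone proceeds with the shared result. A standalone sketch of
# the pattern (`build_fn` is a hypothetical callable, not library API):
import torch.distributed as dist

def build_once_then_share(rank, build_fn):
    if rank != 0:
        dist.barrier()   # non-zero ranks wait for rank 0 to finish
    result = build_fn()  # rank 0 computes and caches; others read the cache
    if rank == 0:
        dist.barrier()   # release the waiting ranks
    return result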
def train_wrapper(self):
    self.dataset.data.train()
    return NeighborSampler(
        dataset=self.train_dataset,
        mask=self.dataset.data.train_mask,
        sizes=self.sample_size,
        num_workers=4,
        shuffle=False,
        batch_size=self.batch_size,
    )
def val_wrapper(self):
    self.dataset.data.eval()
    return NeighborSampler(
        dataset=self.val_dataset,
        mask=self.dataset.data.val_mask,
        sizes=self.sample_size,
        batch_size=self.batch_size * 2,
        shuffle=False,
        num_workers=4,
    )
def test_wrapper(self):
    return (
        self.dataset,
        NeighborSampler(
            dataset=self.test_dataset,
            mask=None,
            sizes=[-1],
            batch_size=self.batch_size * 2,
            shuffle=False,
            num_workers=4,
        ),
    )
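# The test wrapper pairs mask=None with sizes=[-1]: every node, full
# neighborhood. That is what enables exact layer-wise inference, where each
# layer is evaluated over all nodes before the next layer runs, so no sampling
# noise enters evaluation. An illustrative sketch only; the batch layout
# (target nodes first in n_id) and the layer API are assumptions, not the
# library's actual interface:
def layerwise_inference_sketch(layers, x_all, loader, device):
    for layer in layers:
        outs = []
        for target_id, n_id, adj in loader:
            h = layer(x_all[n_id].to(device), adj)
            outs.append(h[: target_id.numel()].cpu())  # keep target rows only
        x_all = torch.cat(outs, dim=0)
    return x_all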
def fit(self, model, dataset):
    self.data = Data.from_pyg_data(dataset[0])

    self.train_loader = NeighborSampler(
        data=self.data,
        mask=self.data.train_mask,
        sizes=self.sample_size,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
    )
    self.test_loader = NeighborSampler(data=self.data, mask=None, sizes=[-1], batch_size=self.batch_size, shuffle=False)

    self.model = model.to(self.device)
    self.model.set_data_device(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    best_model = self.train()
    self.model = best_model
    acc, loss = self._test_step()
    return dict(Acc=acc["test"], ValAcc=acc["val"])
def _test_step(self, split="val"): if split == "test": if torch.__version__.split("+")[0] < "1.7.1": self.test_loader = NeighborSampler( dataset=self.dataset, sizes=[-1], batch_size=self.batch_size * 10, num_workers=self.num_workers, shuffle=False, pin_memory=True, ) else: self.test_loader = NeighborSampler( dataset=self.dataset, sizes=[-1], batch_size=self.batch_size * 10, num_workers=self.num_workers, shuffle=False, persistent_workers=True, pin_memory=True, ) return super(DistributedNeighborSamplerTrainer, self)._test_step()
class NeighborSamplingTrainer(SampledTrainer):
    model: torch.nn.Module

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        SampledTrainer.add_args(parser)
        # fmt: on

    def __init__(self, args):
        super(NeighborSamplingTrainer, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.sample_size = args.sample_size

    def fit(self, model, dataset):
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        settings = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            persistent_workers=True,
            pin_memory=True,
        )
        # Compare versions numerically; a plain string comparison would
        # misorder versions such as "1.10.0" < "1.7.1".
        if [int(x) for x in torch.__version__.split("+")[0].split(".")[:3]] < [1, 7, 1]:
            settings.pop("persistent_workers")

        self.data.train()
        self.train_loader = NeighborSampler(
            dataset=dataset,
            mask=self.data.train_mask,
            sizes=self.sample_size,
            **settings,
        )
        # Evaluation runs without gradients, so a larger batch fits in memory;
        # sizes=[-1] keeps the full neighborhood instead of sampling.
        settings["batch_size"] *= 5
        self.data.eval()
        self.test_loader = NeighborSampler(
            dataset=dataset,
            mask=None,
            sizes=[-1],
            **settings,
        )

        self.model = model.to(self.device)
        self.model.set_data_device(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        acc, loss = self._test_step()
        return dict(Acc=acc["test"], ValAcc=acc["val"])

    def _train_step(self):
        self.data.train()
        self.model.train()
        self.train_loader.shuffle()
        x_all = self.data.x.to(self.device)
        y_all = self.data.y.to(self.device)
        for target_id, n_id, adjs in self.train_loader:
            self.optimizer.zero_grad()
            n_id = n_id.to(x_all.device)
            target_id = target_id.to(y_all.device)
            x_src = x_all[n_id].to(self.device)
            y = y_all[target_id].to(self.device)
            loss = self.model.node_classification_loss(x_src, adjs, y)
            loss.backward()
            self.optimizer.step()

    def _test_step(self, split="val"):
        self.model.eval()
        self.data.eval()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.inference(self.data.x, self.test_loader)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        acc = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return acc, loss

    @classmethod
    def build_trainer_from_args(cls, args):
        return cls(args)
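# A hypothetical end-to-end usage sketch. The flags registered by
# SampledTrainer.add_args are not shown above, so the namespace fields below
# are assumptions that mirror the attributes the trainer reads; `model` and
# `dataset` must follow the interfaces used above (set_data_device,
# node_classification_loss, inference, get_evaluator, get_loss_fn, ...).
import argparse

args = argparse.Namespace(
    hidden_size=128,      # read in __init__
    sample_size=[10, 5],  # neighbor fan-out per layer
    batch_size=256,
    num_workers=4,
    lr=0.01,
    weight_decay=5e-4,
    device=0,             # assumed: consumed by SampledTrainer.__init__
    patience=10,          # assumed: early-stopping knob of SampledTrainer
    max_epoch=100,        # assumed
)
trainer = NeighborSamplingTrainer.build_trainer_from_args(args)
# result = trainer.fit(model, dataset)  # -> {"Acc": ..., "ValAcc": ...}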