示例#1
0
    def build_dataloader(self, dataset, rank):
        """Build the distributed neighbor-sampling train loader and an eval loader.

        Args:
            dataset: dataset whose first element is the graph; the graph's
                ``train()`` / ``eval()`` modes are toggled before each loader
                is built.
            rank: rank of this process, passed to ``DistributedSampler`` to
                shard the training dataset across ``self.args.world_size``
                replicas.

        Returns:
            ``(train_dataset, (train_loader, val_loader, test_loader))``; the
            validation and test loaders are the same object.
        """
        import re

        data = dataset[0]
        # NOTE(review): the mask is read from ``self.data`` rather than the local
        # ``data``; presumably both refer to the same graph — confirm.
        train_dataset = NeighborSamplerDataset(dataset, self.sample_size,
                                               self.batch_size,
                                               self.data.train_mask)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=self.args.world_size, rank=rank)

        settings = dict(
            num_workers=self.num_workers,
            persistent_workers=True,
            pin_memory=True,
            batch_size=self.args.batch_size,
        )

        # Compare numeric version components. A plain string comparison treats
        # "1.10.0"-"1.13.x" as older than "1.7.1" and would wrongly drop the
        # option there. DataLoader also rejects persistent_workers=True when
        # num_workers == 0.
        torch_release = tuple(
            int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
        if torch_release < (1, 7, 1) or not settings["num_workers"]:
            settings.pop("persistent_workers")

        data.train()
        train_loader = NeighborSampler(dataset=train_dataset,
                                       sizes=self.sample_size,
                                       sampler=train_sampler,
                                       **settings)

        # The evaluation loader uses 5x larger batches and sizes=[-1].
        settings["batch_size"] *= 5
        data.eval()
        test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
        val_loader = test_loader
        return train_dataset, (train_loader, val_loader, test_loader)
示例#2
0
    def fit(self, model, dataset):
        """Train ``model`` on ``dataset[0]`` with neighbor-sampled mini-batches.

        Self-loops are added to ``edge_index``, and to ``edge_index_train`` when
        the graph has that attribute, before any loader is built. Training
        samples from nodes in ``train_mask``. The evaluation loader uses
        ``mask=None`` and ``sizes=[-1]``.

        Returns:
            dict with ``Acc`` (the ``"test"`` entry) and ``ValAcc`` (the
            ``"val"`` entry) of the first value returned by ``_test_step``,
            evaluated on the model returned by ``self.train()``.
        """
        self.data = dataset[0]
        graph = self.data

        graph.edge_index, _ = add_remaining_self_loops(graph.edge_index)
        if hasattr(graph, "edge_index_train"):
            graph.edge_index_train, _ = add_remaining_self_loops(graph.edge_index_train)

        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        self.train_loader = NeighborSampler(
            data=graph,
            mask=graph.train_mask,
            sizes=self.sample_size,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )
        # NOTE(review): no num_workers is passed here, so NeighborSampler's
        # default applies — confirm this differs from training on purpose.
        self.test_loader = NeighborSampler(
            data=graph,
            mask=None,
            sizes=[-1],
            batch_size=self.batch_size,
            shuffle=False,
        )

        self.model = model.to(self.device)
        self.model.set_data_device(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # ``train()`` hands back the model to keep (presumably the best one).
        self.model = self.train()
        scores, _ = self._test_step()
        return {"Acc": scores["test"], "ValAcc": scores["val"]}
示例#3
0
    def build_dataloader(self, dataset, rank):
        """Build the distributed GraphSAINT-sampled train loader and an eval loader.

        Args:
            dataset: source dataset. It is wrapped in ``SAINTDataset`` for
                training and passed directly to ``NeighborSampler`` for
                evaluation.
            rank: rank of this process, passed to ``DistributedSampler`` to
                shard the training dataset across ``self.args.world_size``
                replicas.

        Returns:
            ``(train_dataset, (train_loader, val_loader, test_loader))``; the
            validation and test loaders are the same object.
        """
        import re

        train_dataset = SAINTDataset(dataset,
                                     self.sampler_from_args(self.args))
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=self.args.world_size, rank=rank)

        # Compare numeric version components. A plain string comparison treats
        # "1.10.0"-"1.13.x" as older than "1.7.1".
        torch_release = tuple(
            int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
        supports_persistent_workers = torch_release >= (1, 7, 1)

        settings = dict(
            num_workers=self.num_workers,
            persistent_workers=True,
            pin_memory=True,
            batch_size=self.args.batch_size,
        )
        # DataLoader also rejects persistent_workers=True when num_workers == 0.
        if not supports_persistent_workers or not settings["num_workers"]:
            settings.pop("persistent_workers")

        # Each dataset item is already a sampled subgraph, hence batch_size=1
        # with ``batcher`` as the collate function. The same version gate
        # applies here; previously persistent_workers was passed
        # unconditionally, which fails on older torch.
        train_settings = dict(
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            sampler=train_sampler,
            collate_fn=batcher,
        )
        if supports_persistent_workers:
            train_settings["persistent_workers"] = True
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            **train_settings,
        )

        test_loader = NeighborSampler(
            dataset=dataset,
            sizes=[-1],
            **settings,
        )
        val_loader = test_loader
        return train_dataset, (train_loader, val_loader, test_loader)
示例#4
0
    def build_dataloader(self, dataset, rank):
        """Build the distributed cluster-based (METIS) train loader and an eval loader.

        Two ``dist.barrier()`` calls stagger construction. Processes with
        ``self.device != 0`` wait until the process with ``self.device == 0``
        has built the clustered dataset and loader, presumably so that the
        partition is computed once and reused — confirm.

        Args:
            dataset: dataset whose first element is the graph; the graph's
                ``train()`` / ``eval()`` modes are toggled before each loader
                is built.
            rank: rank of this process, passed to ``DistributedSampler``.

        Returns:
            ``(train_dataset, (train_loader, val_loader, test_loader))``; the
            validation and test loaders are the same object.
        """
        import re

        # NOTE(review): the barrier gating uses ``self.device`` rather than
        # ``rank``; presumably the device index equals the rank — confirm.
        if self.device != 0:
            dist.barrier()
        data = dataset[0]
        # NOTE(review): ``self.n_cluster`` here vs ``self.args.n_cluster`` below —
        # confirm these always agree.
        train_dataset = ClusteredDataset(dataset, self.n_cluster,
                                         self.batch_size)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=self.args.world_size, rank=rank)

        settings = dict(
            num_workers=self.num_workers,
            persistent_workers=True,
            pin_memory=True,
            batch_size=self.args.batch_size,
        )

        # Compare numeric version components. A plain string comparison treats
        # "1.10.0"-"1.13.x" as older than "1.7.1". DataLoader also rejects
        # persistent_workers=True when num_workers == 0.
        torch_release = tuple(
            int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
        if torch_release < (1, 7, 1) or not settings["num_workers"]:
            settings.pop("persistent_workers")

        data.train()
        train_loader = ClusteredLoader(dataset=train_dataset,
                                       n_cluster=self.args.n_cluster,
                                       method="metis",
                                       sampler=train_sampler,
                                       **settings)
        if self.device == 0:
            dist.barrier()

        # The evaluation loader uses 5x larger batches and sizes=[-1].
        settings["batch_size"] *= 5
        data.eval()
        test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
        val_loader = test_loader
        return train_dataset, (train_loader, val_loader, test_loader)
示例#5
0
 def train_wrapper(self):
     """Return the neighbor-sampling loader used for training.

     Calls ``train()`` on ``self.dataset.data`` first. The loader then samples
     from ``self.train_dataset`` restricted to ``train_mask``. ``shuffle`` is
     False here; any shuffling must happen elsewhere.
     """
     self.dataset.data.train()
     loader_kwargs = {
         "dataset": self.train_dataset,
         "mask": self.dataset.data.train_mask,
         "sizes": self.sample_size,
         "num_workers": 4,
         "shuffle": False,
         "batch_size": self.batch_size,
     }
     return NeighborSampler(**loader_kwargs)
示例#6
0
    def fit(self, model, dataset):
        """Train ``model`` on ``dataset[0]`` with neighbor-sampled mini-batches.

        The training loader samples from nodes in ``train_mask``. The evaluation
        loader uses ``mask=None``, ``sizes=[-1]`` and 5x larger batches.

        Returns:
            dict with ``Acc`` (the ``"test"`` entry) and ``ValAcc`` (the
            ``"val"`` entry) of the first value returned by ``_test_step``.
        """
        import re

        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        settings = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            persistent_workers=True,
            pin_memory=True,
        )

        # Compare numeric version components. A plain string comparison treats
        # "1.10.0"-"1.13.x" as older than "1.7.1". DataLoader also rejects
        # persistent_workers=True when num_workers == 0.
        torch_release = tuple(
            int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
        if torch_release < (1, 7, 1) or not settings["num_workers"]:
            settings.pop("persistent_workers")

        self.data.train()
        self.train_loader = NeighborSampler(
            dataset=dataset,
            mask=self.data.train_mask,
            sizes=self.sample_size,
            **settings,
        )

        settings["batch_size"] *= 5
        self.data.eval()
        self.test_loader = NeighborSampler(
            dataset=dataset,
            mask=None,
            sizes=[-1],
            **settings,
        )
        self.model = model.to(self.device)
        self.model.set_data_device(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        acc, _ = self._test_step()
        return dict(Acc=acc["test"], ValAcc=acc["val"])
示例#7
0
    def val_wrapper(self):
        """Return the neighbor-sampling loader used for validation.

        Calls ``eval()`` on ``self.dataset.data`` first. The loader then samples
        from ``self.val_dataset`` restricted to ``val_mask``, with batches
        twice the training batch size.
        """
        self.dataset.data.eval()

        loader_kwargs = {
            "dataset": self.val_dataset,
            "mask": self.dataset.data.val_mask,
            "sizes": self.sample_size,
            "batch_size": self.batch_size * 2,
            "shuffle": False,
            "num_workers": 4,
        }
        return NeighborSampler(**loader_kwargs)
示例#8
0
 def test_wrapper(self):
     """Return ``(dataset, loader)`` for evaluation on ``self.test_dataset``.

     The loader uses ``mask=None`` and ``sizes=[-1]``, with batches twice the
     training batch size.
     """
     full_dataset = self.dataset
     test_loader = NeighborSampler(
         dataset=self.test_dataset,
         mask=None,
         sizes=[-1],
         batch_size=self.batch_size * 2,
         shuffle=False,
         num_workers=4,
     )
     return full_dataset, test_loader
示例#9
0
 def fit(self, model, dataset):
     """Convert ``dataset[0]`` via ``Data.from_pyg_data`` and train with neighbor sampling.

     Returns:
         dict with ``Acc`` (the ``"test"`` entry) and ``ValAcc`` (the ``"val"``
         entry) of the first value returned by ``_test_step``, evaluated on
         the model returned by ``self.train()``.
     """
     self.data = Data.from_pyg_data(dataset[0])
     graph = self.data

     self.train_loader = NeighborSampler(
         data=graph,
         mask=graph.train_mask,
         sizes=self.sample_size,
         batch_size=self.batch_size,
         num_workers=self.num_workers,
         shuffle=True,
     )
     # NOTE(review): no num_workers here, so NeighborSampler's default applies.
     self.test_loader = NeighborSampler(
         data=graph,
         mask=None,
         sizes=[-1],
         batch_size=self.batch_size,
         shuffle=False,
     )

     self.model = model.to(self.device)
     self.model.set_data_device(self.device)
     self.optimizer = torch.optim.Adam(
         self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
     )

     self.model = self.train()
     scores, _ = self._test_step()
     return {"Acc": scores["test"], "ValAcc": scores["val"]}
示例#10
0
 def _test_step(self, split="val"):
     """Evaluate, rebuilding ``self.test_loader`` first when ``split == "test"``.

     The rebuilt loader uses ``sizes=[-1]`` and batches 10x the training batch
     size. ``persistent_workers`` is enabled only when torch >= 1.7.1 and
     ``num_workers > 0``.

     NOTE(review): ``split`` is not forwarded to the parent ``_test_step``, so
     the parent runs with its own default — confirm this is intended.
     """
     if split == "test":
         import re

         loader_kwargs = dict(
             dataset=self.dataset,
             sizes=[-1],
             batch_size=self.batch_size * 10,
             num_workers=self.num_workers,
             shuffle=False,
             pin_memory=True,
         )
         # Compare numeric version components. A plain string comparison treats
         # "1.10.0"-"1.13.x" as older than "1.7.1". DataLoader also rejects
         # persistent_workers=True when num_workers == 0.
         torch_release = tuple(
             int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
         if torch_release >= (1, 7, 1) and self.num_workers:
             loader_kwargs["persistent_workers"] = True
         self.test_loader = NeighborSampler(**loader_kwargs)
     return super(DistributedNeighborSamplerTrainer, self)._test_step()
示例#11
0
class NeighborSamplingTrainer(SampledTrainer):
    """Node-classification trainer that mini-batches via ``NeighborSampler``.

    Training iterates over nodes selected by ``train_mask`` with
    ``sizes=args.sample_size``. Evaluation builds a loader with ``mask=None``
    and ``sizes=[-1]`` and calls ``model.inference`` on it.
    """

    model: torch.nn.Module

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        SampledTrainer.add_args(parser)
        # fmt: on

    def __init__(self, args):
        super(NeighborSamplingTrainer, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.sample_size = args.sample_size

    def _loader_settings(self):
        """Return the loader keyword arguments shared by the train and test loaders.

        ``persistent_workers`` is included only when torch >= 1.7.1 and
        ``num_workers > 0``, because torch's DataLoader rejects the option with
        zero workers.
        """
        import re

        settings = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            persistent_workers=True,
            pin_memory=True,
        )
        # Compare numeric version components. A plain string comparison treats
        # "1.10.0"-"1.13.x" as older than "1.7.1".
        torch_release = tuple(
            int(n) for n in re.findall(r"\d+", torch.__version__.split("+")[0])[:3])
        if torch_release < (1, 7, 1) or not settings["num_workers"]:
            settings.pop("persistent_workers")
        return settings

    def fit(self, model, dataset):
        """Train ``model`` on ``dataset[0]`` and report test/validation scores.

        Returns:
            dict with ``Acc`` (the ``"test"`` entry) and ``ValAcc`` (the
            ``"val"`` entry) of the accuracy dict from ``_test_step``.
        """
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        settings = self._loader_settings()

        self.data.train()
        self.train_loader = NeighborSampler(
            dataset=dataset,
            mask=self.data.train_mask,
            sizes=self.sample_size,
            **settings,
        )

        # The evaluation loader uses 5x larger batches and sizes=[-1].
        settings["batch_size"] *= 5
        self.data.eval()
        self.test_loader = NeighborSampler(
            dataset=dataset,
            mask=None,
            sizes=[-1],
            **settings,
        )
        self.model = model.to(self.device)
        self.model.set_data_device(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        acc, _ = self._test_step()
        return dict(Acc=acc["test"], ValAcc=acc["val"])

    def _train_step(self):
        """Run one pass over ``self.train_loader``, taking an optimizer step per batch."""
        self.data.train()
        self.model.train()
        # NOTE(review): ``shuffle()`` is a NeighborSampler method; presumably it
        # reshuffles the target nodes because the loader is built with shuffle=False.
        self.train_loader.shuffle()

        x_all = self.data.x.to(self.device)
        y_all = self.data.y.to(self.device)

        # Each batch yields (target node ids, sampled node ids, adjacencies).
        for target_id, n_id, adjs in self.train_loader:
            self.optimizer.zero_grad()
            n_id = n_id.to(x_all.device)
            target_id = target_id.to(y_all.device)
            x_src = x_all[n_id].to(self.device)

            y = y_all[target_id].to(self.device)
            loss = self.model.node_classification_loss(x_src, adjs, y)
            loss.backward()
            self.optimizer.step()

    def _test_step(self, split="val"):
        """Run full inference and score every split.

        ``split`` is accepted but not used: all of train/val/test are always
        evaluated (presumably kept for interface compatibility).

        Returns:
            ``(acc, loss)``: two dicts keyed by ``"train"``, ``"val"`` and
            ``"test"``, holding ``self.evaluator`` and ``self.loss_fn`` results.
        """
        self.model.eval()
        self.data.eval()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.inference(self.data.x, self.test_loader)

        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        acc = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return acc, loss

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line ``args``."""
        return cls(args)