Example #1
class RandomClusterTrainer(SampledTrainer):
    @staticmethod
    def add_args(parser):
        # fmt: off
        SampledTrainer.add_args(parser)
        parser.add_argument("--n-cluster", type=int, default=10)
        # fmt: on

    def __init__(self, args):
        super(RandomClusterTrainer, self).__init__(args)
        self.patience = args.patience // args.eval_step
        self.n_cluster = args.n_cluster
        self.eval_step = args.eval_step
        self.data, self.optimizer, self.evaluator, self.loss_fn = None, None, None, None

    def fit(self, model, dataset):
        self.model = model.to(self.device)
        self.data = dataset[0]
        self.data.add_remaining_self_loops()

        self.loss_fn = dataset.get_loss_fn()
        self.evaluator = dataset.get_evaluator()

        settings = dict(num_workers=self.num_workers, persistent_workers=True, pin_memory=True)

        # Drop persistent_workers on PyTorch older than 1.7.1. Compare the
        # version numerically: a plain string comparison would misorder
        # releases such as "1.10.0" against "1.7.1" (release-style version
        # strings assumed).
        if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:3]) < (1, 7, 1):
            settings.pop("persistent_workers")

        self.train_loader = ClusteredLoader(dataset=dataset, n_cluster=self.n_cluster, method="random", **settings)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()
        return dict(Acc=metric["test"], ValAcc=metric["val"])

    def _train_step(self):
        self.model.train()
        self.data.train()
        self.train_loader.shuffle()

        for batch in self.train_loader:
            self.optimizer.zero_grad()
            batch = batch.to(self.device)
            loss_n = self.model.node_classification_loss(batch)
            loss_n.backward()
            self.optimizer.step()

    def _test_step(self, split="val"):
        self.model.eval()
        self.data.eval()
        # Full-graph inference runs on the CPU, where the whole graph fits.
        self.model = self.model.cpu()
        data = self.data
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.predict(data)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
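
A hedged usage sketch for the trainer above. It assumes SampledTrainer.add_args registers the shared flags (learning rate, weight decay, patience, eval step) and that model and dataset follow the interface used in fit(); none of that wiring is shown in the example itself.

import argparse

parser = argparse.ArgumentParser()
RandomClusterTrainer.add_args(parser)            # registers --n-cluster plus the shared flags
args = parser.parse_args(["--n-cluster", "20"])

trainer = RandomClusterTrainer(args)
result = trainer.fit(model, dataset)             # model/dataset built elsewhere
print(result)                                    # {"Acc": ..., "ValAcc": ...}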
Example #2
    def fit(self, model, dataset):
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        settings = dict(
            batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True
        )

        # Drop persistent_workers on PyTorch older than 1.7.1. Compare the
        # version numerically: a plain string comparison would misorder
        # releases such as "1.10.0" against "1.7.1" (release-style version
        # strings assumed).
        if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:3]) < (1, 7, 1):
            settings.pop("persistent_workers")

        self.data.train()
        self.train_loader = ClusteredLoader(
            dataset,
            self.n_cluster,
            method="metis",
            **settings,
        )
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()

        return dict(Acc=metric["test"], ValAcc=metric["val"])
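
Compared with Example #1, this fit partitions with METIS (method="metis") rather than random cluster assignment, and batch_size here controls how many clusters are merged into each mini-batch.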
Example #3
    def build_dataloader(self, dataset, rank):
        # Only rank 0 partitions the graph; the other ranks wait at the
        # barrier and reuse the cached partition once it is ready.
        if self.device != 0:
            dist.barrier()
        data = dataset[0]
        train_dataset = ClusteredDataset(dataset, self.n_cluster,
                                         self.batch_size)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=self.args.world_size, rank=rank)

        settings = dict(
            num_workers=self.num_workers,
            persistent_workers=True,
            pin_memory=True,
            batch_size=self.args.batch_size,
        )

        # Drop persistent_workers on PyTorch older than 1.7.1. Compare the
        # version numerically: a plain string comparison would misorder
        # releases such as "1.10.0" against "1.7.1" (release-style version
        # strings assumed).
        if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:3]) < (1, 7, 1):
            settings.pop("persistent_workers")

        data.train()
        train_loader = ClusteredLoader(dataset=train_dataset,
                                       n_cluster=self.args.n_cluster,
                                       method="metis",
                                       sampler=train_sampler,
                                       **settings)
        # Partitioning is finished on rank 0; release the waiting ranks.
        if self.device == 0:
            dist.barrier()

        settings["batch_size"] *= 5
        data.eval()
        test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
        val_loader = test_loader
        return train_dataset, (train_loader, val_loader, test_loader)
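
The barrier pattern above assumes a process group is already initialized on every rank. A minimal launch sketch under that assumption; the run body and the environment values are illustrative, not part of the example.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # One process per GPU; build_dataloader(dataset, rank) is called in here.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the dataloaders and run the training loop ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size)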
Example #4
    def train_wrapper(self):
        self.dataset.data.train()
        return ClusteredLoader(
            self.cluster_dataset,
            method=self.method,
            batch_size=self.batch_size,
            shuffle=True,
            n_cluster=self.n_cluster,
            # persistent_workers=True requires num_workers > 0, so it stays
            # disabled while the loader runs in the main process.
            num_workers=0,
        )
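
For reference, a sketch of how the returned loader is consumed, mirroring the _train_step methods in the other examples; the model, optimizer, and device attributes are assumed to exist on the owning class.

loader = self.train_wrapper()
for batch in loader:
    self.optimizer.zero_grad()
    batch = batch.to(self.device)
    loss = self.model.node_classification_loss(batch)
    loss.backward()
    self.optimizer.step()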
Example #5
    def fit(self, model, dataset):
        self.data = dataset[0]
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.train_loader = ClusteredLoader(self.data, self.n_cluster, batch_size=self.batch_size, shuffle=True)
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()

        return dict(Acc=metric["test"], ValAcc=metric["val"])
Example #6
class ClusterGCNTrainer(SampledTrainer):
    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        SampledTrainer.add_args(parser)
        parser.add_argument("--n-cluster", type=int, default=1000)
        parser.add_argument("--batch-size", type=int, default=20)
        # fmt: on

    @staticmethod
    def get_args4sampler(args):
        args4sampler = {
            "method": "metis",
            "n_cluster": args.n_cluster,
        }
        return args4sampler

    def __init__(self, args):
        super(ClusterGCNTrainer, self).__init__(args)
        self.n_cluster = args.n_cluster
        self.batch_size = args.batch_size

    def fit(self, model, dataset):
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        settings = dict(
            batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True
        )

        # Drop persistent_workers on PyTorch older than 1.7.1. Compare the
        # version numerically: a plain string comparison would misorder
        # releases such as "1.10.0" against "1.7.1" (release-style version
        # strings assumed).
        if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:3]) < (1, 7, 1):
            settings.pop("persistent_workers")

        self.data.train()
        self.train_loader = ClusteredLoader(
            dataset,
            self.n_cluster,
            method="metis",
            **settings,
        )
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()

        return dict(Acc=metric["test"], ValAcc=metric["val"])

    def _train_step(self):
        self.model.train()
        self.data.train()
        self.train_loader.shuffle()
        total_loss = 0
        for batch in self.train_loader:
            self.optimizer.zero_grad()
            batch = batch.to(self.device)
            loss = self.model.node_classification_loss(batch)
            loss.backward()
            total_loss += loss.item()
            self.optimizer.step()
        # Summed loss over all cluster batches this epoch, for logging.
        return total_loss

    def _test_step(self, split="val"):
        self.model.eval()
        self.data.eval()
        data = self.data
        # Full-graph inference runs on the CPU, where the whole graph fits.
        self.model = self.model.cpu()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model(data)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
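
self.train() is inherited from SampledTrainer and not shown in these examples. A plausible sketch, consistent with the patience and eval_step fields set up in Example #1, would be an early-stopping loop like the following; max_epoch and the exact bookkeeping are assumptions, not the library's actual implementation.

import copy

def train(self):
    best_score, best_model, wait = 0.0, None, 0
    for epoch in range(self.max_epoch):          # max_epoch is an assumption
        self._train_step()
        if (epoch + 1) % self.eval_step == 0:
            metric, _ = self._test_step()
            # _test_step above moves the model to the CPU; move it back.
            self.model = self.model.to(self.device)
            if metric["val"] > best_score:
                best_score = metric["val"]
                best_model = copy.deepcopy(self.model)
                wait = 0
            else:
                wait += 1
                if wait >= self.patience:
                    break
    return best_model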