# Imports assumed by the snippets below; the module paths follow CogDL's layout
# and may need adjusting for a different checkout.
import argparse

import torch
import torch.distributed as dist

from cogdl.data.sampler import ClusteredDataset, ClusteredLoader, NeighborSampler
from cogdl.trainers.sampled_trainer import SampledTrainer


class RandomClusterTrainer(SampledTrainer):
    @staticmethod
    def add_args(parser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        SampledTrainer.add_args(parser)
        parser.add_argument("--n-cluster", type=int, default=10)
        # fmt: on

    def __init__(self, args):
        super(RandomClusterTrainer, self).__init__(args)
        self.patience = args.patience // args.eval_step
        self.n_cluster = args.n_cluster
        self.eval_step = args.eval_step
        self.data, self.optimizer, self.evaluator, self.loss_fn = None, None, None, None

    def fit(self, model, dataset):
        self.model = model.to(self.device)
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.loss_fn = dataset.get_loss_fn()
        self.evaluator = dataset.get_evaluator()
        settings = dict(num_workers=self.num_workers, persistent_workers=True, pin_memory=True)
        # lexicographic comparison misorders versions such as "1.10"; compare numeric tuples
        if tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:3] if x.isdigit()) < (1, 7, 1):
            settings.pop("persistent_workers")
        self.train_loader = ClusteredLoader(dataset=dataset, n_cluster=self.n_cluster, method="random", **settings)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()
        return dict(Acc=metric["test"], ValAcc=metric["val"])

    def _train_step(self):
        self.model.train()
        self.data.train()
        self.train_loader.shuffle()
        for batch in self.train_loader:
            self.optimizer.zero_grad()
            batch = batch.to(self.device)
            loss_n = self.model.node_classification_loss(batch)
            loss_n.backward()
            self.optimizer.step()

    def _test_step(self, split="val"):
        self.model.eval()
        self.data.eval()
        data = self.data
        # full-graph inference runs on CPU to avoid GPU OOM on large graphs
        self.model = self.model.cpu()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.predict(data)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
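# A minimal usage sketch for RandomClusterTrainer, assuming a CogDL-style
# node-classification setup. ``build_my_dataset`` and ``MyGNN`` are hypothetical
# placeholders, and the remaining hyperparameters (--patience, --eval-step, ...)
# are expected to be registered by SampledTrainer.add_args.
parser = argparse.ArgumentParser()
RandomClusterTrainer.add_args(parser)
args = parser.parse_args(["--n-cluster", "20"])

dataset = build_my_dataset()  # hypothetical: any dataset exposing get_loss_fn/get_evaluator
model = MyGNN(dataset.num_features, dataset.num_classes)  # hypothetical model

trainer = RandomClusterTrainer(args)
result = trainer.fit(model, dataset)
print(result)  # {"Acc": ..., "ValAcc": ...}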
def build_dataloader(self, dataset, rank):
    # Rank 0 builds the METIS partition first; every other rank waits at the
    # barrier so the (cached) partition is computed exactly once.
    if self.device != 0:
        dist.barrier()
    data = dataset[0]
    train_dataset = ClusteredDataset(dataset, self.n_cluster, self.batch_size)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=self.args.world_size, rank=rank
    )
    settings = dict(
        num_workers=self.num_workers,
        persistent_workers=True,
        pin_memory=True,
        batch_size=self.args.batch_size,
    )
    # lexicographic comparison misorders versions such as "1.10"; compare numeric tuples
    if tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:3] if x.isdigit()) < (1, 7, 1):
        settings.pop("persistent_workers")
    data.train()
    train_loader = ClusteredLoader(
        dataset=train_dataset,
        n_cluster=self.args.n_cluster,
        method="metis",
        sampler=train_sampler,
        **settings,
    )
    # rank 0 enters the barrier only after the partition exists, releasing the others
    if self.device == 0:
        dist.barrier()
    # evaluation uses full-neighborhood sampling, so larger batches are affordable
    settings["batch_size"] *= 5
    data.eval()
    test_loader = NeighborSampler(dataset=dataset, sizes=[-1], **settings)
    val_loader = test_loader
    return train_dataset, (train_loader, val_loader, test_loader)
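# A hedged sketch of how a spawned DDP worker might drive build_dataloader. The
# trainer/dataset objects and the backend choice are assumptions; only the
# torch.distributed and torch.multiprocessing calls below are standard APIs.
import os
import torch.multiprocessing as mp


def _worker(rank, trainer, dataset, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    train_dataset, (train_loader, val_loader, test_loader) = trainer.build_dataloader(dataset, rank)
    # ... wrap the model in DistributedDataParallel and iterate train_loader ...
    dist.destroy_process_group()


# mp.spawn(_worker, args=(trainer, dataset, world_size), nprocs=world_size)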
def train_wrapper(self):
    self.dataset.data.train()
    return ClusteredLoader(
        self.cluster_dataset,
        method=self.method,
        batch_size=self.batch_size,
        shuffle=True,
        n_cluster=self.n_cluster,
        # persistent_workers requires num_workers > 0, so it stays disabled here
        # persistent_workers=True,
        num_workers=0,
    )
def fit(self, model, dataset):
    self.data = dataset[0]
    self.model = model.to(self.device)
    self.evaluator = dataset.get_evaluator()
    self.loss_fn = dataset.get_loss_fn()
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    self.train_loader = ClusteredLoader(self.data, self.n_cluster, batch_size=self.batch_size, shuffle=True)
    best_model = self.train()
    self.model = best_model
    metric, loss = self._test_step()
    return dict(Acc=metric["test"], ValAcc=metric["val"])
class ClusterGCNTrainer(SampledTrainer):
    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        SampledTrainer.add_args(parser)
        parser.add_argument("--n-cluster", type=int, default=1000)
        parser.add_argument("--batch-size", type=int, default=20)
        # fmt: on

    @staticmethod
    def get_args4sampler(args):
        args4sampler = {
            "method": "metis",
            "n_cluster": args.n_cluster,
        }
        return args4sampler

    def __init__(self, args):
        super(ClusterGCNTrainer, self).__init__(args)
        self.n_cluster = args.n_cluster
        self.batch_size = args.batch_size

    def fit(self, model, dataset):
        self.data = dataset[0]
        self.data.add_remaining_self_loops()
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        settings = dict(
            batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True
        )
        # lexicographic comparison misorders versions such as "1.10"; compare numeric tuples
        if tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:3] if x.isdigit()) < (1, 7, 1):
            settings.pop("persistent_workers")
        self.data.train()
        self.train_loader = ClusteredLoader(
            dataset,
            self.n_cluster,
            method="metis",
            **settings,
        )
        best_model = self.train()
        self.model = best_model
        metric, loss = self._test_step()
        return dict(Acc=metric["test"], ValAcc=metric["val"])

    def _train_step(self):
        self.model.train()
        self.data.train()
        self.train_loader.shuffle()
        total_loss = 0
        num_batches = 0
        for batch in self.train_loader:
            self.optimizer.zero_grad()
            batch = batch.to(self.device)
            loss = self.model.node_classification_loss(batch)
            loss.backward()
            total_loss += loss.item()
            num_batches += 1
            self.optimizer.step()
        # report the mean training loss so callers can log it
        return total_loss / max(num_batches, 1)

    def _test_step(self, split="val"):
        self.model.eval()
        self.data.eval()
        data = self.data
        # full-graph inference runs on CPU to avoid GPU OOM on large graphs
        self.model = self.model.cpu()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model(data)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
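# A hedged wiring example for ClusterGCNTrainer, focused on get_args4sampler;
# the model/dataset objects are hypothetical and built elsewhere, as in the
# RandomClusterTrainer sketch above.
parser = argparse.ArgumentParser()
ClusterGCNTrainer.add_args(parser)
args = parser.parse_args(["--n-cluster", "1000", "--batch-size", "20"])

sampler_args = ClusterGCNTrainer.get_args4sampler(args)
# -> {"method": "metis", "n_cluster": 1000}, ready to forward to a sampler factory

trainer = ClusterGCNTrainer(args)
# result = trainer.fit(model, dataset)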