class SAINTTrainer(SampledTrainer):
    """Trainer that optimises a model on GraphSAINT-style sampled subgraphs.

    ``fit`` builds the sampler selected by ``--sampler`` and then delegates
    training to the inherited ``self.train()``, whose return value becomes
    ``self.model``.
    """

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--sampler', default='none', type=str, help='graph samplers')
        parser.add_argument('--sample-coverage', default=20, type=float, help='sample coverage ratio')
        parser.add_argument('--size-subgraph', default=1200, type=int, help='subgraph size')
        parser.add_argument('--num-walks', default=50, type=int, help='number of random walks')
        parser.add_argument('--walk-length', default=20, type=int, help='random walk length')
        parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks')
        # fmt: on

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments."""
        return cls(args)

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related options of ``args`` into a dict.

        The dict is passed unchanged to the sampler constructor in ``fit``.
        """
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Train ``model`` on ``dataset`` and return the model chosen by ``self.train()``.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, an evaluator and a loss function.

        Returns:
            The model returned by ``self.train()`` (also stored in ``self.model``).

        Raises:
            ValueError: if the configured sampler is not one of
                ``node``, ``edge``, ``rw`` or ``mrw``.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        sampler_name = self.args_sampler["sampler"]
        # Validate up front: an unsupported value (including the CLI default
        # 'none') would otherwise leave ``self.sampler`` unset and only fail
        # later with an AttributeError inside ``_train_step``.
        if sampler_name not in samplers:
            raise ValueError(f"Unknown sampler {sampler_name!r}; expected one of {sorted(samplers)}")

        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        self.sampler = samplers[sampler_name](self.data, self.args_sampler)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimisation step on a freshly sampled training subgraph.

        Side effect: ``self.data`` is replaced by the sampled subgraph.
        """
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.node_classification_loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on the subgraph sampled for ``split``.

        Side effect: ``self.data`` is replaced by that subgraph.

        Returns:
            ``(metric, loss)``: two dicts keyed by ``"train"``, ``"val"`` and
            ``"test"``, each computed over the corresponding mask of the subgraph.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.predict(self.data)
        # Unweighted loss: GraphSAINT's ``norm_loss`` normalisation is not applied here.
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
class SAINTTrainer(SampledTrainer):
    """Trainer that optimises a heterogeneous node-classification model on
    GraphSAINT-style sampled subgraphs.

    ``fit`` builds the sampler selected by ``args.sampler`` and then delegates
    training to the inherited ``self.train()``, whose return value becomes
    ``self.model``.
    """

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments."""
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related options of ``args`` into a dict.

        The dict is passed unchanged to the sampler constructor in ``fit``.
        """
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        """Train ``model`` on ``dataset`` and return the model chosen by ``self.train()``.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, which is moved to ``self.device``.

        Returns:
            The model returned by ``self.train()`` (also stored in ``self.model``).

        Raises:
            ValueError: if the configured sampler is not one of
                ``node``, ``edge``, ``rw`` or ``mrw``.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        sampler_name = self.args_sampler["sampler"]
        # Fail fast instead of leaving ``self.sampler`` unset, which would only
        # surface later as an AttributeError inside ``_train_step``.
        if sampler_name not in samplers:
            raise ValueError(f"Unknown sampler {sampler_name!r}; expected one of {sorted(samplers)}")

        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The model must live on the same device as the data it is fed.
        self.model = model.to(self.device)

        self.sampler = samplers[sampler_name](self.data, self.args_sampler)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimisation step on a freshly sampled training subgraph.

        Side effect: ``self.data`` is replaced by the sampled subgraph.
        """
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on the subgraph sampled for ``split``.

        ``"train"`` and ``"val"`` select the matching mask; any other value
        selects ``test_mask``. Side effect: ``self.data`` is replaced by the
        subgraph.

        Returns:
            ``(acc, loss)``: accuracy as a Python float, and the NLL loss
            weighted per node by ``norm_loss`` and summed, as a tensor.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask
        with torch.no_grad():
            logits = self.model.predict(self.data)
        # Per-node loss re-weighted by the sampler's ``norm_loss`` coefficients.
        loss = (torch.nn.NLLLoss(reduction="none")(logits[mask], self.data.y[mask]) * self.data.norm_loss[mask]).sum()
        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        return acc, loss
class SAINTTrainer(SampledTrainer):
    """Trainer that optimises a heterogeneous node-classification model on
    GraphSAINT-style sampled subgraphs, with its own early-stopping loop.
    """

    def __init__(self, args):
        # NOTE(review): this variant does not call ``super().__init__``; all
        # attributes the trainer needs are set directly from ``args``.
        self.device = args.device_id[0] if not args.cpu else "cpu"
        self.patience = args.patience
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments.

        The previous body was ``pass``, so callers received ``None``.
        """
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related options of ``args`` into a dict.

        The dict is passed unchanged to the sampler constructor in ``fit``.
        """
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        """Train ``model`` with early stopping on validation accuracy/loss.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, which is moved to ``self.device``.

        Returns:
            A deep copy of the best model seen (also stored in ``self.model``),
            or the untrained model if no epoch ran.

        Raises:
            ValueError: if the configured sampler is not one of
                ``node``, ``edge``, ``rw`` or ``mrw``.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        sampler_name = self.args_sampler["sampler"]
        # Fail fast instead of leaving ``self.sampler`` unset.
        if sampler_name not in samplers:
            raise ValueError(f"Unknown sampler {sampler_name!r}; expected one of {sorted(samplers)}")

        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The model must live on the same device as the data it is fed.
        self.model = model.to(self.device)
        self.sampler = samplers[sampler_name](self.data, self.args_sampler)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        epoch_iter = tqdm(range(self.max_epoch))
        patience = 0
        best_score = 0
        max_score = 0
        min_loss = np.inf
        # Fallback so the return below is well-defined even if max_epoch == 0.
        best_model = self.model
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            # Compare as a plain float: numpy cannot convert CUDA tensors or
            # tensors that require grad.
            val_loss = float(val_loss)
            epoch_iter.set_description(
                f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
            )
            if val_loss <= min_loss or val_acc >= max_score:
                # Only accuracy decides the best model: SAINT loss is not accurate.
                if val_acc >= best_score:
                    best_score = val_acc
                    best_model = copy.deepcopy(self.model)
                min_loss = np.min((min_loss, val_loss))
                max_score = np.max((max_score, val_acc))
                patience = 0
            else:
                patience += 1
                if patience == self.patience:
                    epoch_iter.close()
                    break
        # Keep ``self.model`` pointing at the best model whether or not early
        # stopping triggered.
        self.model = best_model
        return best_model

    def _train_step(self):
        """Run one optimisation step on a freshly sampled training subgraph.

        Side effect: ``self.data`` is replaced by the sampled subgraph.
        """
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on the subgraph sampled for ``split``.

        ``"train"`` and ``"val"`` select the matching mask; any other value
        selects ``test_mask``. Side effect: ``self.data`` is replaced by the
        subgraph.

        Returns:
            ``(acc, loss)``: accuracy as a Python float, and the NLL loss
            weighted per node by ``norm_loss`` and summed, as a tensor.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask
        # No autograd graph is needed for evaluation.
        with torch.no_grad():
            logits = self.model.predict(self.data)
        loss = (torch.nn.NLLLoss(reduction='none')(logits[mask], self.data.y[mask]) * self.data.norm_loss[mask]).sum()
        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        return acc, loss