Ejemplo n.º 1
0
class SAINTTrainer(SampledTrainer):
    """Trainer that optimizes a model on subgraphs produced by a graph sampler.

    The sampler kind ("node", "edge", "rw" or "mrw") and its parameters come
    from the command-line arguments registered in ``add_args``. Every training
    and evaluation step requests a subgraph from that sampler.
    """

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # NOTE: the default sampler 'none' is not a valid choice for ``fit``;
        # callers must pass one of node/edge/rw/mrw.
        # fmt: off
        parser.add_argument('--sampler', default='none', type=str, help='graph samplers')
        parser.add_argument('--sample-coverage', default=20, type=float, help='sample coverage ratio')
        parser.add_argument('--size-subgraph', default=1200, type=int, help='subgraph size')
        parser.add_argument('--num-walks', default=50, type=int, help='number of random walks')
        parser.add_argument('--walk-length', default=20, type=int, help='random walk length')
        parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks')
        # fmt: on

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments."""
        return cls(args)

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related fields of ``args`` into a plain dict."""
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def _build_sampler(self):
        """Instantiate the sampler named by ``self.args_sampler["sampler"]``.

        Returns:
            A sampler constructed from ``self.data`` and ``self.args_sampler``.

        Raises:
            ValueError: if the sampler name is not one of the supported kinds.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        name = self.args_sampler["sampler"]
        try:
            sampler_cls = samplers[name]
        except KeyError:
            raise ValueError(
                f"Unknown sampler {name!r}; expected one of {sorted(samplers)}"
            ) from None
        return sampler_cls(self.data, self.args_sampler)

    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Train ``model`` on ``dataset`` using sampled subgraphs.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, an evaluator and a loss function.

        Returns:
            The model returned by ``self.train()``, also stored in ``self.model``.

        Raises:
            ValueError: if the configured sampler name is not supported.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        # Fail fast. Without a sampler, the first training step would crash
        # later with an AttributeError on ``self.sampler``.
        self.sampler = self._build_sampler()

        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimizer update on a freshly requested training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.node_classification_loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on a subgraph sampled for ``split``.

        Returns:
            ``(metric, loss)``: two dicts keyed by "train", "val" and "test",
            holding the evaluator output and loss over each mask of the
            sampled subgraph.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.predict(self.data)
            # NOTE(review): this loss is not weighted by ``self.data.norm_loss``
            # (SAINT normalization) -- confirm that this is intended.
            loss = {key: self.loss_fn(logits[mask], self.data.y[mask]) for key, mask in masks.items()}
        metric = {key: self.evaluator(logits[mask], self.data.y[mask]) for key, mask in masks.items()}
        return metric, loss
Ejemplo n.º 2
0
class SAINTTrainer(SampledTrainer):
    """Trainer that optimizes a model on subgraphs produced by a graph sampler.

    The sampler kind ("node", "edge", "rw" or "mrw") and its parameters are
    read from the command-line ``args`` in ``sampler_from_args``.
    """

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments."""
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related fields of ``args`` into a plain dict."""
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def _build_sampler(self):
        """Instantiate the sampler named by ``self.args_sampler["sampler"]``.

        Raises:
            ValueError: if the sampler name is not one of the supported kinds.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        name = self.args_sampler["sampler"]
        try:
            sampler_cls = samplers[name]
        except KeyError:
            raise ValueError(
                f"Unknown sampler {name!r}; expected one of {sorted(samplers)}"
            ) from None
        return sampler_cls(self.data, self.args_sampler)

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        """Train ``model`` on ``dataset`` using sampled subgraphs.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, which is moved to ``self.device``.

        Returns:
            The model returned by ``self.train()``, also stored in ``self.model``.

        Raises:
            ValueError: if the configured sampler name is not supported.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The data and every sampled subgraph live on ``self.device``, so the
        # model must be moved there too to avoid a device mismatch.
        self.model = model.to(self.device)
        # Fail fast instead of crashing later on a missing ``self.sampler``.
        self.sampler = self._build_sampler()

        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimizer update on a freshly requested training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on a subgraph sampled for ``split``.

        Args:
            split: "train" or "val"; any other value selects the test mask.

        Returns:
            ``(acc, loss)``. ``acc`` is the accuracy over the split's mask as a
            Python float. ``loss`` is the NLL loss over that mask, weighted per
            node by ``norm_loss`` and summed (a tensor).
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask

        with torch.no_grad():
            # Assumes ``predict`` returns log-probabilities, as NLLLoss
            # expects -- TODO confirm.
            logits = self.model.predict(self.data)
            loss = (torch.nn.NLLLoss(reduction="none")(logits[mask], self.data.y[mask]) * self.data.norm_loss[mask]).sum()

        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        return acc, loss
Ejemplo n.º 3
0
class SAINTTrainer(SampledTrainer):
    """Trainer that optimizes a model on subgraphs produced by a graph sampler.

    The training loop in ``fit`` selects the model with the best validation
    accuracy and stops early after ``patience`` epochs without improvement in
    validation loss or accuracy.
    """

    def __init__(self, args):
        self.device = args.device_id[0] if not args.cpu else "cpu"
        self.patience = args.patience
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct the trainer from parsed command-line arguments."""
        # Previously a staticmethod whose body was ``pass``, so it returned None.
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler-related fields of ``args`` into a plain dict."""
        return {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }

    def _build_sampler(self):
        """Instantiate the sampler named by ``self.args_sampler["sampler"]``.

        Raises:
            ValueError: if the sampler name is not one of the supported kinds.
        """
        samplers = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        name = self.args_sampler["sampler"]
        try:
            sampler_cls = samplers[name]
        except KeyError:
            raise ValueError(
                f"Unknown sampler {name!r}; expected one of {sorted(samplers)}"
            ) from None
        return sampler_cls(self.data, self.args_sampler)

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel,
            dataset: Dataset):
        """Train ``model`` with early stopping and return the best model.

        Args:
            model: the model to train; it is moved to ``self.device``.
            dataset: provides ``data``, which is moved to ``self.device``.

        Returns:
            A deep copy of the model from the epoch with the best validation
            accuracy. If no epoch ran, the model itself is returned. The result
            is also stored in ``self.model``.

        Raises:
            ValueError: if the configured sampler name is not supported.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model.to(self.device)
        # Fail fast instead of crashing later on a missing ``self.sampler``.
        self.sampler = self._build_sampler()

        self.optimizer = torch.optim.Adam(model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        epoch_iter = tqdm(range(self.max_epoch))
        patience = 0
        best_score = 0
        max_score = 0
        min_loss = np.inf
        # Fallback so that ``max_epoch == 0`` does not leave best_model unbound.
        best_model = self.model
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            # Use a plain float so the numpy min/max below never has to convert
            # a (possibly CUDA) tensor.
            val_loss = float(val_loss)
            epoch_iter.set_description(
                f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
            )
            if val_loss <= min_loss or val_acc >= max_score:
                if val_acc >= best_score:  # SAINT loss is not accurate
                    best_score = val_acc
                    best_model = copy.deepcopy(self.model)
                min_loss = np.min((min_loss, val_loss))
                max_score = np.max((max_score, val_acc))
                patience = 0
            else:
                patience += 1
                if patience == self.patience:
                    epoch_iter.close()
                    break
        # Keep self.model consistent with the returned model whether training
        # stopped early or ran for all epochs.
        self.model = best_model
        return best_model

    def _train_step(self):
        """Run one optimizer update on a freshly requested training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate the model on a subgraph sampled for ``split``.

        Args:
            split: "train" or "val"; any other value selects the test mask.

        Returns:
            ``(acc, loss)``. ``acc`` is the accuracy over the split's mask as a
            Python float. ``loss`` is the NLL loss over that mask, weighted per
            node by ``norm_loss`` and summed (a tensor computed without
            autograd).
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))

        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask

        # Evaluation must not build an autograd graph; a grad-requiring loss
        # also cannot be converted to numpy in ``fit``.
        with torch.no_grad():
            # Assumes ``predict`` returns log-probabilities, as NLLLoss
            # expects -- TODO confirm.
            logits = self.model.predict(self.data)
            loss = (torch.nn.NLLLoss(reduction='none')(logits[mask],
                                                       self.data.y[mask]) *
                    self.data.norm_loss[mask]).sum()

        pred = logits[mask].max(1)[1]
        acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
        return acc, loss