Exemplo n.º 1
0
    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Build the GraphSAINT sampler and optimizer for ``model``, then train.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset supplying ``data``, an evaluator and a loss function.

        Returns:
            Whatever ``self.train()`` returns.

        Raises:
            NotImplementedError: If ``self.args_sampler["sampler"]`` is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.dataset = dataset
        self.data = dataset.data
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        # Map each supported sampler name to the class that implements it.
        sampler_classes = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }
        sampler_name = self.args_sampler["sampler"]
        sampler_cls = sampler_classes.get(sampler_name)
        if sampler_cls is None:
            raise NotImplementedError(f"Unsupported sampler: {sampler_name!r}")
        self.sampler = sampler_cls(self.data, self.args_sampler)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        return self.train()
Exemplo n.º 2
0
    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel,
            dataset: Dataset):
        """Train ``model`` with GraphSAINT sampling and patience-based early stopping.

        Each epoch runs one ``_train_step`` and then evaluates on sampled
        ``"train"`` and ``"val"`` subgraphs. An epoch counts as an improvement
        when the validation loss reaches a new minimum or the validation
        accuracy reaches a new maximum. Training stops after ``self.patience``
        consecutive epochs without improvement. The returned snapshot is chosen
        by validation accuracy only, because the sampled loss is noisy.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset whose ``data`` is moved to ``self.device``.

        Returns:
            The best model snapshot. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The data lives on self.device, so the model must too.
        self.model = model.to(self.device)
        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        epoch_iter = tqdm(range(self.max_epoch))
        patience = 0
        best_score = 0
        max_score = 0
        min_loss = float("inf")
        # Fallback so a model is returned even if no epoch runs (max_epoch == 0).
        best_model = self.model
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            # _test_step may return the loss as a tensor (possibly on GPU or
            # requiring grad), which numpy cannot reduce; compare it as a float.
            val_loss = float(val_loss)
            epoch_iter.set_description(
                f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
            )
            if val_loss <= min_loss or val_acc >= max_score:
                if val_acc >= best_score:  # SAINT loss is not accurate
                    best_score = val_acc
                    best_model = copy.deepcopy(self.model)
                min_loss = min(min_loss, val_loss)
                max_score = max(max_score, val_acc)
                patience = 0
            else:
                patience += 1
                if patience == self.patience:
                    epoch_iter.close()
                    break
        # Leave the trainer holding the best snapshot whether we stopped early
        # or ran out of epochs.
        self.model = best_model
        return best_model
Exemplo n.º 3
0
    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        """Build the GraphSAINT sampler and optimizer, train, and keep the best model.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset whose ``data`` is moved to ``self.device``.

        Returns:
            The model returned by ``self.train()``. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The data lives on self.device, so the model must too.
        self.model = model.to(self.device)
        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        best_model = self.train()
        self.model = best_model
        return self.model
Exemplo n.º 4
0
    def set_data_model(self, dataset: Dataset, model: SupervisedModel):
        """Attach ``dataset`` and ``model`` to the trainer and build sampler and optimizer.

        Stores the dataset, its ``data``, evaluator and loss function. Moves the
        model to ``self.device``, then creates the subgraph sampler named by
        ``self.args_sampler["sampler"]`` and an Adam optimizer.

        Raises:
            NotImplementedError: If the sampler name is not ``"node"``,
                ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.dataset = dataset
        self.data = dataset.data
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        # Dispatch table from sampler name to implementing class.
        kind = self.args_sampler["sampler"]
        sampler_cls = {
            "node": NodeSampler,
            "edge": EdgeSampler,
            "rw": RWSampler,
            "mrw": MRWSampler,
        }.get(kind)
        if sampler_cls is None:
            raise NotImplementedError
        self.sampler = sampler_cls(self.data, self.args_sampler)

        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
Exemplo n.º 5
0
    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Build the GraphSAINT sampler and optimizer, train, and keep the best model.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset supplying ``data`` (moved to ``self.device``), an
                evaluator and a loss function.

        Returns:
            The model returned by ``self.train()``. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        return self.model
Exemplo n.º 6
0
class SAINTTrainer(SampledTrainer):
    """GraphSAINT trainer: optimizes a model on sampled subgraphs."""

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct a trainer from parsed command-line ``args``."""
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler configuration from ``args`` into a dict."""
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier
        }
        return args_sampler

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel, dataset: Dataset):
        """Build the sampler and optimizer, train, and keep the best model.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset whose ``data`` is moved to ``self.device``.

        Returns:
            The model returned by ``self.train()``. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The data lives on self.device, so the model must too.
        self.model = model.to(self.device)
        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimization step on a freshly sampled training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate on a sampled subgraph for ``split``.

        Returns:
            ``(acc, loss)``: the accuracy over the split's mask as a float, and the
            NLL loss weighted by ``data.norm_loss`` and summed, as a tensor.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask

        with torch.no_grad():
            logits = self.model.predict(self.data)
            loss = (torch.nn.NLLLoss(reduction="none")(logits[mask], self.data.y[mask]) * self.data.norm_loss[mask]).sum()

        pred = logits[mask].max(1)[1]
        # A sampled subgraph may contain no node of this split; avoid 0/0.
        num_nodes = max(mask.sum().item(), 1)
        acc = pred.eq(self.data.y[mask]).sum().item() / num_nodes
        return acc, loss
Exemplo n.º 7
0
class SAINTTrainer(SampledTrainer):
    """GraphSAINT trainer: optimizes a model on sampled subgraphs."""

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--sampler', default='none', type=str, help='graph samplers')
        parser.add_argument('--sample-coverage', default=20, type=float, help='sample coverage ratio')
        parser.add_argument('--size-subgraph', default=1200, type=int, help='subgraph size')
        parser.add_argument('--num-walks', default=50, type=int, help='number of random walks')
        parser.add_argument('--walk-length', default=20, type=int, help='random walk length')
        parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks')
        # fmt: on

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct a trainer from parsed command-line ``args``."""
        return cls(args)

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        self.args_sampler = self.sampler_from_args(args)

    def sampler_from_args(self, args):
        """Collect the sampler configuration from ``args`` into a dict."""
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Build the sampler and optimizer, train, and keep the best model.

        Args:
            model: Model to train; it is moved to ``self.device``.
            dataset: Dataset supplying ``data`` (moved to ``self.device``), an
                evaluator and a loss function.

        Returns:
            The model returned by ``self.train()``. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_model = self.train()
        self.model = best_model
        return self.model

    def _train_step(self):
        """Run one optimization step on a freshly sampled training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.node_classification_loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate on a subgraph sampled for ``split``.

        ``split`` only selects which subgraph is sampled. Metrics and losses are
        then computed for all three masks of that subgraph.

        Returns:
            ``(metric, loss)``: two dicts keyed by ``"train"``, ``"val"`` and
            ``"test"``, computed with ``self.evaluator`` and ``self.loss_fn``.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))
        self.model.eval()
        masks = {"train": self.data.train_mask, "val": self.data.val_mask, "test": self.data.test_mask}
        with torch.no_grad():
            logits = self.model.predict(self.data)
        loss = {key: self.loss_fn(logits[val], self.data.y[val]) for key, val in masks.items()}
        metric = {key: self.evaluator(logits[val], self.data.y[val]) for key, val in masks.items()}
        return metric, loss
Exemplo n.º 8
0
class SAINTTrainer(SampledTrainer):
    """GraphSAINT trainer: optimizes a model on sampled subgraphs.

    The sampler type and its hyper-parameters come from the command-line
    arguments registered in ``add_args``.
    """

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add trainer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--sampler',
                            default='none',
                            type=str,
                            help='graph samplers')
        parser.add_argument('--sample-coverage',
                            default=20,
                            type=float,
                            help='sample coverage ratio')
        parser.add_argument('--size-subgraph',
                            default=1200,
                            type=int,
                            help='subgraph size')
        parser.add_argument('--num-walks',
                            default=50,
                            type=int,
                            help='number of random walks')
        parser.add_argument('--walk-length',
                            default=20,
                            type=int,
                            help='random walk length')
        parser.add_argument(
            '--size-frontier',
            default=20,
            type=int,
            help='frontier size in multidimensional random walks')
        parser.add_argument('--valid-cpu',
                            action='store_true',
                            help='run validation on cpu')
        # fmt: on

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct a trainer from parsed command-line ``args``."""
        return cls(args)

    def __init__(self, args):
        super(SAINTTrainer, self).__init__(args)
        # When set, non-train evaluation runs with the model on the CPU.
        self.valid_cpu = args.valid_cpu
        self.args_sampler = self.sampler_from_args(args)

    def sampler_from_args(self, args):
        """Collect the sampler configuration from ``args`` into a dict."""
        args_sampler = {
            "sampler": args.sampler,
            "sample_coverage": args.sample_coverage,
            "size_subgraph": args.size_subgraph,
            "num_walks": args.num_walks,
            "walk_length": args.walk_length,
            "size_frontier": args.size_frontier,
        }
        return args_sampler

    def set_data_model(self, dataset: Dataset, model: SupervisedModel):
        """Attach ``dataset`` and ``model`` and build the sampler and optimizer.

        Raises:
            NotImplementedError: If the sampler name is not ``"node"``,
                ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.dataset = dataset
        self.data = dataset.data
        self.model = model.to(self.device)
        self.evaluator = dataset.get_evaluator()
        self.loss_fn = dataset.get_loss_fn()

        if self.args_sampler["sampler"] == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif self.args_sampler["sampler"] == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            raise NotImplementedError

        self.optimizer = torch.optim.Adam(model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)

    def fit(self, model: SupervisedModel, dataset: Dataset):
        """Prepare data/model/sampler/optimizer and return ``self.train()``."""
        self.set_data_model(dataset, model)
        return self.train()

    def _train_step(self):
        """Run one optimization step on a freshly sampled training batch."""
        self.data = self.sampler.one_batch("train")
        self.data.apply(lambda x: x.to(self.device))

        # _test_step may have moved the model to the CPU (valid_cpu).
        self.model = self.model.to(self.device)
        self.model.train()
        self.optimizer.zero_grad()

        mask = self.data.train_mask
        if len(self.data.y.shape) > 1:
            # Multi-dimensional targets: binary cross-entropy on logits, with
            # per-node weights taken from data.norm_loss.
            logits = self.model.predict(self.data)
            weight = self.data.norm_loss[mask].unsqueeze(1)
            loss = torch.nn.BCEWithLogitsLoss(reduction="sum", weight=weight)(
                logits[mask], self.data.y[mask].float())
        else:
            # Class-index targets: NLL over log-probabilities, weighted per node
            # by data.norm_loss. dim=1 is the class axis of 2-D logits (the
            # same axis the deprecated implicit-dim call picked).
            logits = torch.nn.functional.log_softmax(
                self.model.predict(self.data), dim=1)
            loss = (torch.nn.NLLLoss(reduction="none")(logits[mask],
                                                       self.data.y[mask]) *
                    self.data.norm_loss[mask]).sum()
        loss.backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate on a batch sampled for ``split``.

        ``split`` selects which batch is sampled. Metrics and losses are then
        computed for all three masks of that batch.

        Returns:
            ``(metric, loss)``: two dicts keyed by ``"train"``, ``"val"`` and
            ``"test"``, computed with ``self.evaluator`` and ``self.loss_fn``.
        """
        self.data = self.sampler.one_batch(split)
        if split != "train" and self.valid_cpu:
            # Evaluate with the model on the CPU; the batch is not moved
            # (presumably the sampler yields CPU tensors - confirm).
            self.model = self.model.cpu()
        else:
            self.data.apply(lambda x: x.to(self.device))
            # An earlier valid_cpu pass may have left the model on the CPU.
            self.model = self.model.to(self.device)
        self.model.eval()
        masks = {
            "train": self.data.train_mask,
            "val": self.data.val_mask,
            "test": self.data.test_mask
        }
        with torch.no_grad():
            logits = self.model.predict(self.data)

        loss = {
            key: self.loss_fn(logits[val], self.data.y[val])
            for key, val in masks.items()
        }
        metric = {
            key: self.evaluator(logits[val], self.data.y[val])
            for key, val in masks.items()
        }
        return metric, loss
Exemplo n.º 9
0
class SAINTTrainer(SampledTrainer):
    """GraphSAINT trainer with its own early-stopping loop in ``fit``."""

    def __init__(self, args):
        # Hyper-parameters are read directly from args; the base-class
        # initializer is intentionally not called here (unchanged from before).
        self.device = args.device_id[0] if not args.cpu else "cpu"
        self.patience = args.patience
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.args_sampler = self.sampler_from_args(args)

    @classmethod
    def build_trainer_from_args(cls, args):
        """Construct a trainer from parsed command-line ``args``.

        Previously a staticmethod that returned ``None``; it is still callable
        as ``SAINTTrainer.build_trainer_from_args(args)``.
        """
        return cls(args)

    def sampler_from_args(self, args):
        """Collect the sampler configuration from ``args`` into a dict."""
        args_sampler = {}
        args_sampler["sampler"] = args.sampler
        args_sampler["sample_coverage"] = args.sample_coverage
        args_sampler["size_subgraph"] = args.size_subgraph
        args_sampler["num_walks"] = args.num_walks
        args_sampler["walk_length"] = args.walk_length
        args_sampler["size_frontier"] = args.size_frontier
        return args_sampler

    def fit(self, model: SupervisedHeterogeneousNodeClassificationModel,
            dataset: Dataset):
        """Train ``model`` with GraphSAINT sampling and patience-based early stopping.

        An epoch counts as an improvement when the validation loss reaches a
        new minimum or the validation accuracy reaches a new maximum. Training
        stops after ``self.patience`` consecutive epochs without improvement.
        The returned snapshot is chosen by validation accuracy only.

        Returns:
            The best model snapshot. It is also stored in ``self.model``.

        Raises:
            NotImplementedError: If the configured sampler is not one of
                ``"node"``, ``"edge"``, ``"rw"`` or ``"mrw"``.
        """
        self.data = dataset.data
        self.data.apply(lambda x: x.to(self.device))
        # The data lives on self.device, so the model must too.
        self.model = model.to(self.device)
        sampler = self.args_sampler["sampler"]
        if sampler == "node":
            self.sampler = NodeSampler(self.data, self.args_sampler)
        elif sampler == "edge":
            self.sampler = EdgeSampler(self.data, self.args_sampler)
        elif sampler == "rw":
            self.sampler = RWSampler(self.data, self.args_sampler)
        elif sampler == "mrw":
            self.sampler = MRWSampler(self.data, self.args_sampler)
        else:
            # Fail fast instead of silently reusing a stale (or missing) sampler.
            raise NotImplementedError(f"Unsupported sampler: {sampler!r}")

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        epoch_iter = tqdm(range(self.max_epoch))
        patience = 0
        best_score = 0
        max_score = 0
        min_loss = float("inf")
        # Fallback so a model is returned even if no epoch runs (max_epoch == 0).
        best_model = self.model
        for epoch in epoch_iter:
            self._train_step()
            train_acc, _ = self._test_step(split="train")
            val_acc, val_loss = self._test_step(split="val")
            # The loss comes back as a tensor; numpy cannot reduce a CUDA or
            # grad-requiring tensor, so compare it as a plain float.
            val_loss = float(val_loss)
            epoch_iter.set_description(
                f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
            )
            if val_loss <= min_loss or val_acc >= max_score:
                if val_acc >= best_score:  # SAINT loss is not accurate
                    best_score = val_acc
                    best_model = copy.deepcopy(self.model)
                min_loss = min(min_loss, val_loss)
                max_score = max(max_score, val_acc)
                patience = 0
            else:
                patience += 1
                if patience == self.patience:
                    epoch_iter.close()
                    break
        # Leave the trainer holding the best snapshot whether we stopped early
        # or ran out of epochs.
        self.model = best_model
        return best_model

    def _train_step(self):
        """Run one optimization step on a freshly sampled training subgraph."""
        self.data = self.sampler.get_subgraph("train")
        self.data.apply(lambda x: x.to(self.device))
        self.model.train()
        self.optimizer.zero_grad()
        self.model.loss(self.data).backward()
        self.optimizer.step()

    def _test_step(self, split="val"):
        """Evaluate on a sampled subgraph for ``split``.

        Returns:
            ``(acc, loss)``: the accuracy over the split's mask as a float, and the
            NLL loss weighted by ``data.norm_loss`` and summed, as a tensor.
        """
        self.data = self.sampler.get_subgraph(split)
        self.data.apply(lambda x: x.to(self.device))

        self.model.eval()
        if split == "train":
            mask = self.data.train_mask
        elif split == "val":
            mask = self.data.val_mask
        else:
            mask = self.data.test_mask

        # Evaluation must not build an autograd graph.
        with torch.no_grad():
            logits = self.model.predict(self.data)
            loss = (torch.nn.NLLLoss(reduction='none')(logits[mask],
                                                       self.data.y[mask]) *
                    self.data.norm_loss[mask]).sum()

        pred = logits[mask].max(1)[1]
        # A sampled subgraph may contain no node of this split; avoid 0/0.
        num_nodes = max(mask.sum().item(), 1)
        acc = pred.eq(self.data.y[mask]).sum().item() / num_nodes
        return acc, loss