Example #1
import itertools

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

# Han, HeteroNetDataset, Metrics, and filter_samples are project-local
# dependencies of this module; their exact import paths are not shown here.


class HAN(Han, pl.LightningModule):
    def __init__(self,
                 hparams,
                 dataset: HeteroNetDataset,
                 metrics=["precision"]):
        num_edge = len(dataset.edge_index_dict)
        num_layers = hparams.num_layers
        num_class = dataset.n_classes
        self.collate_fn = hparams.collate_fn
        self.multilabel = dataset.multilabel
        num_nodes = dataset.num_nodes_dict[dataset.head_node_type]

        # Input width: use the dataset's node features when present; otherwise
        # the learned embeddings below provide hparams.embedding_dim features.
        if dataset.in_features:
            w_in = dataset.in_features
        else:
            w_in = hparams.embedding_dim

        w_out = hparams.embedding_dim

        super(HAN, self).__init__(num_edge=num_edge,
                                  w_in=w_in,
                                  w_out=w_out,
                                  num_class=num_class,
                                  num_nodes=num_nodes,
                                  num_layers=num_layers)

        # If the dataset provides no node features, learn an embedding table.
        if not hasattr(dataset, "x") and not hasattr(dataset, "x_dict"):
            if num_nodes > 10000:
                # Large node sets: keep the table in a plain dict on CPU so it
                # is not registered as a module parameter and spares GPU memory.
                self.embedding = {
                    dataset.head_node_type:
                    torch.nn.Embedding(
                        num_embeddings=num_nodes,
                        embedding_dim=hparams.embedding_dim).cpu()
                }
            else:
                self.embedding = torch.nn.Embedding(
                    num_embeddings=num_nodes,
                    embedding_dim=hparams.embedding_dim)

        self.dataset = dataset
        self.head_node_type = self.dataset.head_node_type
        hparams.n_params = self.get_n_params()
        self.train_metrics = Metrics(prefix="",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.valid_metrics = Metrics(prefix="val_",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.test_metrics = Metrics(prefix="test_",
                                    loss_type=hparams.loss_type,
                                    n_classes=dataset.n_classes,
                                    multilabel=dataset.multilabel,
                                    metrics=metrics)
        hparams.name = self.name()
        hparams.inductive = dataset.inductive
        self.hparams = hparams

    def name(self):
        return getattr(self, "_name", self.__class__.__name__)

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean().item()
        logs = self.train_metrics.compute_metrics()

        logs.update({"loss": avg_loss})
        self.train_metrics.reset_metrics()
        return {"log": logs}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
        logs = self.valid_metrics.compute_metrics()

        logs.update({"val_loss": avg_loss})
        self.valid_metrics.reset_metrics()
        return {"progress_bar": logs, "log": logs}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean().item()
        if hasattr(self, "test_metrics"):
            logs = self.test_metrics.compute_metrics()
            self.test_metrics.reset_metrics()
        else:
            logs = {}
        logs.update({"test_loss": avg_loss})

        return {"progress_bar": logs, "log": logs}

    def print_pred_class_counts(self, y_hat, y, multilabel, n_top_class=8):
        # Summarize the predicted vs. true class distributions. Multilabel
        # targets are summarized by the number of positive labels per sample;
        # single-label targets by the argmax class.
        if multilabel:
            y_pred = y_hat.sum(1)
            y_true = y.sum(1)
        else:
            y_pred = y_hat.argmax(1)
            y_true = y

        y_pred_dict = pd.Series(y_pred.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        y_true_dict = pd.Series(y_true.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()

        print(f"y_pred {len(y_pred_dict)} classes",
              {str(k): v for k, v in itertools.islice(y_pred_dict.items(),
                                                      n_top_class)})
        print(f"y_true {len(y_true_dict)} classes",
              {str(k): v for k, v in itertools.islice(y_true_dict.items(),
                                                      n_top_class)})

    def get_n_params(self):
        # Total number of scalar parameters in the model.
        return sum(param.numel() for param in self.parameters())

    def forward(self, A, X, x_idx):
        # When no input features are given, look them up in the learned
        # embedding table (which may live on CPU for large node sets).
        if X is None:
            if isinstance(self.embedding, dict):
                X = self.embedding[self.head_node_type].weight[x_idx].to(
                    self.layers[0].device)
            else:
                X = self.embedding.weight[x_idx]

        for i in range(self.num_layers):
            X = self.layers[i](X, A)

        # If the layers return embeddings for more nodes than were requested,
        # select only the head nodes before the final classification layer.
        if x_idx is not None and X.size(0) > x_idx.size(0):
            y = self.linear(X[x_idx])
        else:
            y = self.linear(X)
        return y

    def loss(self, y_hat, y):
        if self.multilabel:
            return F.binary_cross_entropy_with_logits(y_hat, y.type_as(y_hat))
        return self.cross_entropy_loss(y_hat, y)

    def training_step(self, batch, batch_nb):
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.train_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.valid_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)

        return {"val_loss": loss}

    def test_step(self, batch, batch_nb):
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.test_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)

        return {"test_loss": loss}

    def train_dataloader(self):
        return self.dataset.train_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def val_dataloader(self):
        return self.dataset.valid_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def valtrain_dataloader(self):
        return self.dataset.valtrain_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        return self.dataset.test_dataloader(collate_fn=self.collate_fn,
                                            batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
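
A minimal usage sketch for the class above (the hparams fields, their values, and the dataset object are illustrative assumptions, not taken from the source):

# Hypothetical usage -- every hparams field below is an assumption, and
# `dataset` stands in for a HeteroNetDataset instance.
from argparse import Namespace

hparams = Namespace(embedding_dim=128, num_layers=2, lr=1e-3,
                    batch_size=512, loss_type="SOFTMAX_CROSS_ENTROPY",
                    collate_fn="HAN_batch")
model = HAN(hparams, dataset, metrics=["precision"])

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)  # dataloaders come from the model's *_dataloader hooks
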
Example #2
import itertools

import pandas as pd
import pytorch_lightning as pl
import torch

# Metrics is a project-local dependency; its import path is not shown here.


class NodeClfMetrics(pl.LightningModule):
    def __init__(self, hparams, dataset, metrics, *args):
        super().__init__(*args)

        self.train_metrics = Metrics(prefix="",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.valid_metrics = Metrics(prefix="val_",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.test_metrics = Metrics(prefix="test_",
                                    loss_type=hparams.loss_type,
                                    n_classes=dataset.n_classes,
                                    multilabel=dataset.multilabel,
                                    metrics=metrics)
        hparams.name = self.name()
        hparams.inductive = dataset.inductive
        self.hparams = hparams

    def name(self):
        return getattr(self, "_name", self.__class__.__name__)

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean().item()
        logs = self.train_metrics.compute_metrics()

        logs.update({"loss": avg_loss})
        self.train_metrics.reset_metrics()
        return {"log": logs}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
        logs = self.valid_metrics.compute_metrics()

        logs.update({"val_loss": avg_loss})
        self.valid_metrics.reset_metrics()
        return {"progress_bar": logs, "log": logs}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean().item()
        if hasattr(self, "test_metrics"):
            logs = self.test_metrics.compute_metrics()
            self.test_metrics.reset_metrics()
        else:
            logs = {}
        logs.update({"test_loss": avg_loss})

        return {"progress_bar": logs, "log": logs}

    def print_pred_class_counts(self, y_hat, y, multilabel, n_top_class=8):
        # Summarize the predicted vs. true class distributions. Multilabel
        # targets are summarized by the number of positive labels per sample;
        # single-label targets by the argmax class.
        if multilabel:
            y_pred = y_hat.sum(1)
            y_true = y.sum(1)
        else:
            y_pred = y_hat.argmax(1)
            y_true = y

        y_pred_dict = pd.Series(y_pred.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        y_true_dict = pd.Series(y_true.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()

        print(f"y_pred {len(y_pred_dict)} classes",
              {str(k): v for k, v in itertools.islice(y_pred_dict.items(),
                                                      n_top_class)})
        print(f"y_true {len(y_true_dict)} classes",
              {str(k): v for k, v in itertools.islice(y_true_dict.items(),
                                                      n_top_class)})

    def get_n_params(self):
        # Total number of scalar parameters in the model.
        return sum(param.numel() for param in self.parameters())
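
NodeClfMetrics is meant to be subclassed: it owns the Metrics bookkeeping and the epoch-end aggregation, while a subclass supplies the model and step logic. A minimal sketch, assuming a plain (x, y) feature dataset; everything here besides NodeClfMetrics itself is hypothetical:

import torch.nn.functional as F

class MLPNodeClf(NodeClfMetrics):
    # Hypothetical one-layer classifier, only to illustrate the contract.
    def __init__(self, hparams, dataset, metrics=["precision"]):
        super().__init__(hparams, dataset, metrics)
        self.linear = torch.nn.Linear(dataset.in_features, dataset.n_classes)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        self.train_metrics.update_metrics(y_hat, y, weights=None)
        return {"loss": F.cross_entropy(y_hat, y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)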