class HAN(Han, pl.LightningModule):
    """Lightning wrapper around the ``Han`` heterogeneous graph attention
    network for node classification on a :class:`HeteroNetDataset`.

    Handles embedding initialization for featureless graphs, the
    train/validation/test steps, metric bookkeeping, and dataloaders.
    """

    def __init__(self, hparams, dataset: HeteroNetDataset, metrics=None):
        """Build the HAN model from ``hparams`` and the dataset's metadata.

        :param hparams: namespace with ``num_layers``, ``embedding_dim``,
            ``collate_fn``, ``loss_type``, ``batch_size``, ``lr``.
        :param dataset: heterogeneous-network dataset providing edge types,
            class counts, node counts, and (optionally) input features.
        :param metrics: metric names to track; defaults to ``["precision"]``.
            (A ``None`` sentinel replaces the previous mutable list default.)
        """
        if metrics is None:
            metrics = ["precision"]

        num_edge = len(dataset.edge_index_dict)
        num_layers = hparams.num_layers
        num_class = dataset.n_classes
        self.collate_fn = hparams.collate_fn
        self.multilabel = dataset.multilabel
        num_nodes = dataset.num_nodes_dict[dataset.head_node_type]

        # Fall back to the configured embedding size when the dataset has no
        # input features of its own.
        if dataset.in_features:
            w_in = dataset.in_features
        else:
            w_in = hparams.embedding_dim
        w_out = hparams.embedding_dim

        super(HAN, self).__init__(num_edge=num_edge, w_in=w_in, w_out=w_out,
                                  num_class=num_class, num_nodes=num_nodes,
                                  num_layers=num_layers)

        # Featureless dataset: learn an embedding table. Very large tables
        # (> 10000 nodes) are kept on CPU inside a plain dict so Lightning
        # does not move them to GPU; `forward` transfers looked-up rows to
        # the layer's device on demand.
        if not hasattr(dataset, "x") and not hasattr(dataset, "x_dict"):
            if num_nodes > 10000:
                self.embedding = {
                    dataset.head_node_type: torch.nn.Embedding(
                        num_embeddings=num_nodes,
                        embedding_dim=hparams.embedding_dim).cpu()
                }
            else:
                self.embedding = torch.nn.Embedding(
                    num_embeddings=num_nodes,
                    embedding_dim=hparams.embedding_dim)

        self.dataset = dataset
        self.head_node_type = self.dataset.head_node_type
        hparams.n_params = self.get_n_params()

        self.train_metrics = Metrics(prefix="", loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.valid_metrics = Metrics(prefix="val_",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.test_metrics = Metrics(prefix="test_",
                                    loss_type=hparams.loss_type,
                                    n_classes=dataset.n_classes,
                                    multilabel=dataset.multilabel,
                                    metrics=metrics)
        hparams.name = self.name()
        hparams.inductive = dataset.inductive
        self.hparams = hparams

    def name(self):
        """Return the model's display name: ``_name`` if set, else the class
        name."""
        if hasattr(self, "_name"):
            return self._name
        return self.__class__.__name__

    def training_epoch_end(self, outputs):
        """Aggregate training-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean().item()
        logs = self.train_metrics.compute_metrics()
        logs.update({"loss": avg_loss})
        self.train_metrics.reset_metrics()
        return {"log": logs}

    def validation_epoch_end(self, outputs):
        """Aggregate validation-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
        logs = self.valid_metrics.compute_metrics()
        logs.update({"val_loss": avg_loss})
        self.valid_metrics.reset_metrics()
        return {"progress_bar": logs, "log": logs}

    def test_epoch_end(self, outputs):
        """Aggregate test-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean().item()
        if hasattr(self, "test_metrics"):
            logs = self.test_metrics.compute_metrics()
            self.test_metrics.reset_metrics()
        else:
            logs = {}
        logs.update({"test_loss": avg_loss})
        return {"progress_bar": logs, "log": logs}

    def print_pred_class_counts(self, y_hat, y, multilabel, n_top_class=8):
        """Print the ``n_top_class`` most frequent predicted and true class
        counts (debugging aid).

        For multilabel tasks, counts the number of labels per sample; for
        single-label tasks, counts argmax predictions vs. raw labels.
        """
        if multilabel:
            pred_values = y_hat.sum(1)
            true_values = y.sum(1)
        else:
            pred_values = y_hat.argmax(1)
            true_values = y
        # Shared counting pipeline (previously duplicated in both branches).
        y_pred_dict = pd.Series(pred_values.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        y_true_dict = pd.Series(true_values.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        print(f"y_pred {len(y_pred_dict)} classes", {
            str(k): v
            for k, v in itertools.islice(y_pred_dict.items(), n_top_class)
        })
        print(f"y_true {len(y_true_dict)} classes", {
            str(k): v
            for k, v in itertools.islice(y_true_dict.items(), n_top_class)
        })

    def get_n_params(self):
        """Return the total number of scalar parameters in the model."""
        # Tensor.numel() replaces the hand-rolled size product.
        return sum(param.numel() for param in self.parameters())

    def forward(self, A, X, x_idx):
        """Run the HAN layers over adjacency ``A`` and features ``X``.

        :param A: adjacency structure consumed by each HAN layer.
        :param X: node features, or ``None`` to look up learned embeddings.
        :param x_idx: indices of the head nodes to classify.
        :return: class logits for the selected nodes.
        """
        if X is None:
            # Featureless graph: fetch embedding rows. A dict-wrapped table
            # lives on CPU, so move the selected rows to the layer's device.
            if isinstance(self.embedding, dict):
                X = self.embedding[self.head_node_type].weight[x_idx].to(
                    self.layers[0].device)
            else:
                X = self.embedding.weight[x_idx]
        for i in range(self.num_layers):
            X = self.layers[i].forward(X, A)
        # If X covers more nodes than requested, classify only x_idx's rows.
        if x_idx is not None and X.size(0) > x_idx.size(0):
            y = self.linear(X[x_idx])
        else:
            y = self.linear(X)
        return y

    def loss(self, y_hat, y):
        """Cross-entropy for single-label targets, BCE-with-logits for
        multilabel targets."""
        if not self.multilabel:
            loss = self.cross_entropy_loss(y_hat, y)
        else:
            loss = F.binary_cross_entropy_with_logits(y_hat, y.type_as(y_hat))
        return loss

    def training_step(self, batch, batch_nb):
        """One training step: forward, filter weighted samples, log metrics."""
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.train_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        """One validation step; mirrors :meth:`training_step`."""
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.valid_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)
        return {"val_loss": loss}

    def test_step(self, batch, batch_nb):
        """One test step; mirrors :meth:`training_step`."""
        X, y, weights = batch
        y_hat = self.forward(X["adj"], X["x"], X["idx"])
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        self.test_metrics.update_metrics(y_hat, y, weights=None)
        loss = self.loss(y_hat, y)
        return {"test_loss": loss}

    def train_dataloader(self):
        """Training dataloader built with the configured collate function."""
        return self.dataset.train_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def val_dataloader(self):
        """Validation dataloader built with the configured collate function."""
        return self.dataset.valid_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def valtrain_dataloader(self):
        """Combined validation+training dataloader (dataset-provided)."""
        return self.dataset.valtrain_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        """Test dataloader built with the configured collate function."""
        return self.dataset.test_dataloader(
            collate_fn=self.collate_fn, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        """Adam over all parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
class NodeClfMetrics(pl.LightningModule):
    """Base LightningModule for node-classification models.

    Sets up train/validation/test :class:`Metrics` trackers, aggregates
    per-epoch losses and metric logs, and provides small debugging helpers.
    Subclasses implement the actual ``*_step`` and ``forward`` logic.
    """

    def __init__(self, hparams, dataset, metrics, *args):
        """Create metric trackers and record run metadata on ``hparams``.

        :param hparams: namespace with at least ``loss_type``.
        :param dataset: dataset exposing ``n_classes``, ``multilabel``,
            ``inductive``.
        :param metrics: metric names passed through to :class:`Metrics`.
        :param args: forwarded to :class:`pl.LightningModule`.
        """
        super().__init__(*args)

        self.train_metrics = Metrics(prefix="", loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.valid_metrics = Metrics(prefix="val_",
                                     loss_type=hparams.loss_type,
                                     n_classes=dataset.n_classes,
                                     multilabel=dataset.multilabel,
                                     metrics=metrics)
        self.test_metrics = Metrics(prefix="test_",
                                    loss_type=hparams.loss_type,
                                    n_classes=dataset.n_classes,
                                    multilabel=dataset.multilabel,
                                    metrics=metrics)
        hparams.name = self.name()
        hparams.inductive = dataset.inductive
        self.hparams = hparams

    def name(self):
        """Return the model's display name: ``_name`` if set, else the class
        name."""
        if hasattr(self, "_name"):
            return self._name
        return self.__class__.__name__

    def training_epoch_end(self, outputs):
        """Aggregate training-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean().item()
        logs = self.train_metrics.compute_metrics()
        logs.update({"loss": avg_loss})
        self.train_metrics.reset_metrics()
        return {"log": logs}

    def validation_epoch_end(self, outputs):
        """Aggregate validation-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
        logs = self.valid_metrics.compute_metrics()
        logs.update({"val_loss": avg_loss})
        self.valid_metrics.reset_metrics()
        return {"progress_bar": logs, "log": logs}

    def test_epoch_end(self, outputs):
        """Aggregate test-step losses and metrics into epoch logs."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean().item()
        if hasattr(self, "test_metrics"):
            logs = self.test_metrics.compute_metrics()
            self.test_metrics.reset_metrics()
        else:
            logs = {}
        logs.update({"test_loss": avg_loss})
        return {"progress_bar": logs, "log": logs}

    def print_pred_class_counts(self, y_hat, y, multilabel, n_top_class=8):
        """Print the ``n_top_class`` most frequent predicted and true class
        counts (debugging aid).

        For multilabel tasks, counts the number of labels per sample; for
        single-label tasks, counts argmax predictions vs. raw labels.
        """
        if multilabel:
            pred_values = y_hat.sum(1)
            true_values = y.sum(1)
        else:
            pred_values = y_hat.argmax(1)
            true_values = y
        # Shared counting pipeline (previously duplicated in both branches).
        y_pred_dict = pd.Series(pred_values.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        y_true_dict = pd.Series(true_values.detach().cpu().type(
            torch.int).numpy()).value_counts().to_dict()
        print(f"y_pred {len(y_pred_dict)} classes", {
            str(k): v
            for k, v in itertools.islice(y_pred_dict.items(), n_top_class)
        })
        print(f"y_true {len(y_true_dict)} classes", {
            str(k): v
            for k, v in itertools.islice(y_true_dict.items(), n_top_class)
        })

    def get_n_params(self):
        """Return the total number of scalar parameters in the model."""
        # Tensor.numel() replaces the hand-rolled size product.
        return sum(param.numel() for param in self.parameters())