class HGTNodeClassifier(NodeClfMetrics):
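    """Node classification module for DGL heterogeneous graphs: embeds the sampled
    blocks with the HGT encoder and scores nodes of the head node type with a dense
    classification head."""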
    def __init__(self, hparams, dataset: DGLNodeSampler, metrics=["accuracy"], collate_fn="neighbor_sampler") -> None:
        super(HGTNodeClassifier, self).__init__(hparams=hparams, dataset=dataset, metrics=metrics)
        self.head_node_type = dataset.head_node_type
        self.dataset = dataset
        self.multilabel = dataset.multilabel
        self.y_types = list(dataset.y_dict.keys())
        self._name = f"LATTE-{hparams.t_order}{' proximity' if hparams.use_proximity else ''}"
        self.collate_fn = collate_fn

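        # HGT indexes node and edge types by integer ID; one HGT layer is stacked per
        # neighbor-sampling hop so the model depth matches the sampled blocks.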
        self.embedder = HGT(node_dict={ntype: i for i, ntype in enumerate(dataset.node_types)},
                            edge_dict={metapath[1]: i for i, metapath in enumerate(dataset.get_metapaths())},
                            n_inp=self.dataset.node_attr_shape[self.head_node_type],
                            n_hid=hparams.embedding_dim, n_out=hparams.embedding_dim,
                            n_layers=len(self.dataset.neighbor_sizes),
                            n_heads=hparams.attn_heads)

        self.classifier = DenseClassification(hparams)

        self.criterion = ClassificationLoss(
            n_classes=dataset.n_classes,
            class_weight=dataset.class_weight
            if hasattr(dataset, "class_weight") and hparams.use_class_weights else None,
            loss_type=hparams.loss_type,
            multilabel=dataset.multilabel)
        self.hparams.n_params = self.get_n_params()

    def forward(self, blocks, batch_inputs: dict, **kwargs):
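        """Embed the sampled blocks and return logits for nodes of the head node type.

        `blocks` are the message-flow graphs produced by the DGL neighbor sampler, and
        `batch_inputs` maps each node type to its input feature tensor.
        """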
        embeddings = self.embedder.forward(blocks, batch_inputs)

        y_hat = self.classifier.forward(embeddings[self.head_node_type])
        return y_hat

    def training_step(self, batch, batch_nb):
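        # A DGL neighbor-sampling batch: the IDs of all sampled input nodes, the IDs of
        # the seed (output) nodes, and one message-flow-graph block per layer.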
        input_nodes, seeds, blocks = batch
        batch_inputs = blocks[0].srcdata['feat']
        batch_labels = blocks[-1].dstdata['labels'][self.head_node_type]

        y_hat = self.forward(blocks, batch_inputs)
        loss = self.criterion.forward(y_hat, batch_labels)

        self.train_metrics.update_metrics(y_hat, batch_labels, weights=None)

        logs = None

        outputs = {'loss': loss}
        if logs is not None:
            outputs.update({'progress_bar': logs, "logs": logs})
        return outputs

    def validation_step(self, batch, batch_nb):
        input_nodes, seeds, blocks = batch
        batch_inputs = blocks[0].srcdata['feat']
        batch_labels = blocks[-1].dstdata['labels'][self.head_node_type]

        y_hat = self.forward(blocks, batch_inputs)

        val_loss = self.criterion.forward(y_hat, batch_labels)

        self.valid_metrics.update_metrics(y_hat, batch_labels, weights=None)

        return {"val_loss": val_loss}

    def test_step(self, batch, batch_nb):
        input_nodes, seeds, blocks = batch
        batch_inputs = blocks[0].srcdata['feat']
        batch_labels = blocks[-1].dstdata['labels'][self.head_node_type]

        y_hat = self.forward(blocks, batch_inputs)

        test_loss = self.criterion.forward(y_hat, batch_labels)

        if batch_nb == 0:
            self.print_pred_class_counts(y_hat, batch_labels, multilabel=self.dataset.multilabel)

        self.test_metrics.update_metrics(y_hat, batch_labels, weights=None)

        return {"test_loss": test_loss}

    def train_dataloader(self):
        return self.dataset.train_dataloader(collate_fn=None,
                                             batch_size=self.hparams.batch_size,
                                             num_workers=max(1, int(0.4 * multiprocessing.cpu_count())))

    def val_dataloader(self, batch_size=None):
        return self.dataset.valid_dataloader(collate_fn=None,
                                             batch_size=self.hparams.batch_size,
                                             num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def valtrain_dataloader(self):
        return self.dataset.valtrain_dataloader(collate_fn=None,
                                                batch_size=self.hparams.batch_size,
                                                num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def test_dataloader(self, batch_size=None):
        return self.dataset.test_dataloader(collate_fn=None,
                                            batch_size=self.hparams.batch_size,
                                            num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'alpha_activation']
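        # Two parameter groups: standard weight decay for most weights, and no weight
        # decay for biases and attention-activation parameters.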
        optimizer_grouped_parameters = [
            {'params': [p for name, p in param_optimizer if not any(key in name for key in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for name, p in param_optimizer if any(key in name for key in no_decay)], 'weight_decay': 0.0}
        ]

        # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-06, lr=self.hparams.lr)

        optimizer = torch.optim.Adam(optimizer_grouped_parameters,
                                     lr=self.hparams.lr,  # momentum=self.hparams.momentum,
                                     weight_decay=self.hparams.weight_decay)
        scheduler = ReduceLROnPlateau(optimizer)

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
    def __init__(self, hparams):
        torch.nn.Module.__init__(self)

        assert isinstance(
            hparams.encoder, dict
        ), "hparams.encoder must be a dict keyed by node type. For a single node type, use MonoplexEmbedder instead."
        assert isinstance(
            hparams.embedder, dict
        ), "hparams.embedder must be a dict keyed by network layer. For a single layer, use MonoplexEmbedder instead."
        self.hparams = copy.copy(hparams)

        ################### Encoding ####################
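        # Build one encoder per node type; hparams.encoder maps each node type to an
        # encoder name, or to a {name: sub-hparams} dict for the NodeIDEmbedding and
        # Linear options.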
        self.node_types = list(hparams.encoder.keys())
        for node_type, encoder in hparams.encoder.items():
            if encoder == "ConvLSTM":
                hparams.vocab_size = self.hparams.vocab_size[node_type]
                self.set_encoder(node_type, ConvLSTM(hparams))

            elif encoder == "Albert":
                config = AlbertConfig(
                    vocab_size=hparams.vocab_size,
                    embedding_size=hparams.word_embedding_size,
                    hidden_size=hparams.encoding_dim,
                    num_hidden_layers=hparams.num_hidden_layers,
                    num_hidden_groups=hparams.num_hidden_groups,
                    hidden_dropout_prob=hparams.hidden_dropout_prob,
                    attention_probs_dropout_prob=hparams.attention_probs_dropout_prob,
                    num_attention_heads=hparams.num_attention_heads,
                    intermediate_size=hparams.intermediate_size,
                    type_vocab_size=1,
                    max_position_embeddings=hparams.max_length,
                )
                self.set_encoder(node_type, AlbertEncoder(config))

            elif "NodeIDEmbedding" in encoder:
                # `encoder` is a dict with {"NodeIDEmbedding": hparams}
                self.set_encoder(
                    node_type,
                    NodeIDEmbedding(hparams=encoder["NodeIDEmbedding"]))

            elif "Linear" in encoder:
                encoder_hparams = encoder["Linear"]
                self.set_encoder(
                    node_type,
                    torch.nn.Linear(in_features=encoder_hparams["in_features"],
                                    out_features=hparams.encoding_dim))

            else:
                raise Exception(
                    "hparams.encoder must be one of {'ConvLSTM', 'Albert', 'NodeIDEmbedding', 'Linear'}"
                )

        ################### Layer-specific Embedding ####################
        self.layers = list(hparams.embedder)
        if hparams.multiplex_embedder == "ExpandedMultiplexGAT":
            self._embedder = ExpandedMultiplexGAT(
                in_channels=hparams.encoding_dim,
                out_channels=int(hparams.embedding_dim / len(self.node_types)),
                node_types=self.node_types,
                layers=self.layers,
                dropout=hparams.nb_attn_dropout)
        else:
            print(
                '"multiplex_embedder" not used. Concatenating multi-layer embeddings instead.'
            )

        ################### Classifier ####################
        if hparams.classifier == "Dense":
            self._classifier = DenseClassification(hparams)
        elif hparams.classifier == "HierarchicalAWX":
            self._classifier = HierarchicalAWX(hparams)
        else:
            raise Exception("hparams.classifier must be one of {'Dense', 'HierarchicalAWX'}")

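        # Optional label-hierarchy setup: map each class label to an index and load the
        # hierarchy relations passed to ClassificationLoss for its hierarchical penalty.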
        if hparams.use_hierar:
            label_map = pd.Series(range(len(hparams.classes)),
                                  index=hparams.classes).to_dict()
            hierar_relations = get_hierar_relations(
                hparams.hierar_taxonomy_file, label_map=label_map)

        self.criterion = ClassificationLoss(
            n_classes=hparams.n_classes,
            class_weight=torch.tensor(hparams.class_weight) if hasattr(hparams, "class_weight") else None,
            loss_type=hparams.loss_type,
            hierar_penalty=hparams.hierar_penalty if hparams.use_hierar else None,
            hierar_relations=hierar_relations if hparams.use_hierar else None)
    def __init__(self, hparams):
        torch.nn.Module.__init__(self)

        assert isinstance(
            hparams.encoder, dict
        ), "hparams.encoder must be a dict keyed by node type. For a single node type, use MonoplexEmbedder instead."
        assert isinstance(
            hparams.embedder, dict
        ), "hparams.embedder must be a dict keyed by network layer. For a single layer, use MonoplexEmbedder instead."
        self.hparams = hparams

        ################### Encoding ####################
        self.node_types = list(hparams.encoder.keys())
        for node_type, encoder in hparams.encoder.items():
            if encoder == "ConvLSTM":
                assert len(hparams.encoder) <= 1 or len(hparams.vocab_size) > 1, \
                    "When encoding multiple node types, hparams.vocab_size must provide one vocab size per node type."
                self.set_encoder(node_type, ConvLSTM(hparams))

            elif encoder == "Albert":
                assert len(hparams.encoder) <= 1 or len(hparams.vocab_size) > 1, \
                    "When encoding multiple node types, hparams.vocab_size must provide one vocab size per node type."
                config = AlbertConfig(
                    vocab_size=hparams.vocab_size,
                    embedding_size=hparams.word_embedding_size,
                    hidden_size=hparams.encoding_dim,
                    num_hidden_layers=hparams.num_hidden_layers,
                    num_hidden_groups=hparams.num_hidden_groups,
                    hidden_dropout_prob=hparams.hidden_dropout_prob,
                    attention_probs_dropout_prob=hparams.attention_probs_dropout_prob,
                    num_attention_heads=hparams.num_attention_heads,
                    intermediate_size=hparams.intermediate_size,
                    type_vocab_size=1,
                    max_position_embeddings=hparams.max_length,
                )
                self.set_encoder(node_type, AlbertEncoder(config))

            elif "NodeIDEmbedding" in encoder:
                # `encoder` is a dict with {"NodeIDEmbedding": hparams}
                self.set_encoder(
                    node_type,
                    NodeIDEmbedding(hparams=encoder["NodeIDEmbedding"]))
            elif "Linear" in encoder:
                encoder_hparams = encoder["Linear"]
                self.set_encoder(
                    node_type,
                    torch.nn.Linear(in_features=encoder_hparams["in_features"],
                                    out_features=hparams.encoding_dim))

            else:
                raise Exception(
                    "hparams.encoder must be one of {'ConvLSTM', 'Albert', 'NodeIDEmbedding', 'Linear'}"
                )

        ################### Layer-specific Embedding ####################
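        # Build one graph embedder (GAT, GCN or GraphSAGE) per subnetwork layer of the
        # multiplex network.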
        for subnetwork_type, embedder_model in hparams.embedder.items():
            if embedder_model == "GAT":
                self.set_embedder(subnetwork_type, GAT(hparams))
            elif embedder_model == "GCN":
                self.set_embedder(subnetwork_type, GCN(hparams))
            elif embedder_model == "GraphSAGE":
                self.set_embedder(subnetwork_type, GraphSAGE(hparams))
            else:
                raise Exception(
                    f"Embedder model for hparams.embedder[{subnetwork_type}] must be one of ['GAT', 'GCN', 'GraphSAGE']"
                )

        ################### Multiplex Embedding ####################
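        # Optionally fuse the per-layer embeddings with an attention-based multiplex
        # embedder; otherwise the per-layer embeddings are simply concatenated, so the
        # effective embedding_dim grows with the number of layers.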
        layers = list(hparams.embedder.keys())
        self.layers = layers
        if hparams.multiplex_embedder == "MultiplexLayerAttention":
            self._multiplex_embedder = MultiplexLayerAttention(
                embedding_dim=hparams.embedding_dim,
                hidden_dim=hparams.multiplex_hidden_dim,
                attention_dropout=hparams.multiplex_attn_dropout,
                layers=layers)
            hparams.embedding_dim = hparams.multiplex_hidden_dim
        elif hparams.multiplex_embedder == "MultiplexNodeAttention":
            self._multiplex_embedder = MultiplexNodeAttention(
                embedding_dim=hparams.embedding_dim,
                hidden_dim=hparams.multiplex_hidden_dim,
                attention_dropout=hparams.multiplex_attn_dropout,
                layers=layers)
            hparams.embedding_dim = hparams.multiplex_hidden_dim
        else:
            print(
                '"multiplex_embedder" not used. Concatenating multi-layer embeddings instead.'
            )
            hparams.embedding_dim = hparams.embedding_dim * len(hparams.embedder)

        ################### Classifier ####################
        if hparams.classifier == "Dense":
            self._classifier = DenseClassification(hparams)
        elif hparams.classifier == "HierarchicalAWX":
            self._classifier = HierarchicalAWX(hparams)
        else:
            raise Exception("hparams.classifier must be one of {'Dense', 'HierarchicalAWX'}")

        if hparams.use_hierar:
            label_map = pd.Series(range(len(hparams.classes)),
                                  index=hparams.classes).to_dict()
            hierar_relations = get_hierar_relations(
                hparams.hierar_taxonomy_file, label_map=label_map)

        self.criterion = ClassificationLoss(
            n_classes=hparams.n_classes,
            class_weight=torch.tensor(hparams.class_weight) if hasattr(hparams, "class_weight") else None,
            loss_type=hparams.loss_type,
            hierar_penalty=hparams.hierar_penalty if hparams.use_hierar else None,
            hierar_relations=hierar_relations if hparams.use_hierar else None)
class LATTENodeClassifier(NodeClfMetrics):
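    """Node classification module that embeds a heterogeneous network with the LATTE
    encoder and scores nodes of the head node type with a dense classification head."""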
    def __init__(self,
                 hparams,
                 dataset: HeteroNetDataset,
                 metrics=["accuracy"],
                 collate_fn="neighbor_sampler") -> None:
        super(LATTENodeClassifier, self).__init__(hparams=hparams,
                                                  dataset=dataset,
                                                  metrics=metrics)
        self.head_node_type = dataset.head_node_type
        self.dataset = dataset
        self.multilabel = dataset.multilabel
        self.y_types = list(dataset.y_dict.keys())
        self._name = f"LATTE-{hparams.t_order}{' proximity' if hparams.use_proximity else ''}"
        self.collate_fn = collate_fn

        self.latte = LATTE(t_order=hparams.t_order,
                           embedding_dim=hparams.embedding_dim,
                           in_channels_dict=dataset.node_attr_shape,
                           num_nodes_dict=dataset.num_nodes_dict,
                           metapaths=dataset.get_metapaths(),
                           activation=hparams.activation,
                           attn_heads=hparams.attn_heads,
                           attn_activation=hparams.attn_activation,
                           attn_dropout=hparams.attn_dropout,
                           use_proximity=hparams.use_proximity,
                           neg_sampling_ratio=hparams.neg_sampling_ratio)
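        # The classifier input is of size embedding_dim * t_order (one embedding_dim-sized
        # block per relation order), so update hparams.embedding_dim before building it.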
        hparams.embedding_dim = hparams.embedding_dim * hparams.t_order

        self.classifier = DenseClassification(hparams)
        # self.classifier = MulticlassClassification(num_feature=hparams.embedding_dim,
        #                                            num_class=hparams.n_classes,
        #                                            loss_type=hparams.loss_type)
        self.criterion = ClassificationLoss(
            n_classes=dataset.n_classes,
            class_weight=dataset.class_weight
            if hasattr(dataset, "class_weight") and hparams.use_class_weights else None,
            loss_type=hparams.loss_type,
            multilabel=dataset.multilabel)
        self.hparams.n_params = self.get_n_params()

    def forward(self, input: dict, **kwargs):
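        """Run the LATTE encoder on one sampled batch and classify the head-type nodes.

        `input` is expected to carry "x_dict" (node features per type), "edge_index_dict"
        (edge indices per metapath) and "global_node_index" (sampled node IDs per type),
        as produced by the neighbor-sampler collate function.
        """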
        embeddings, proximity_loss, _ = self.latte.forward(
            X=input["x_dict"],
            edge_index_dict=input["edge_index_dict"],
            global_node_idx=input["global_node_index"],
            **kwargs)
        y_hat = self.classifier.forward(embeddings[self.head_node_type])
        return y_hat, proximity_loss

    def training_step(self, batch, batch_nb):
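        # Each batch is a triple: the sampled subgraph tensors X, the labels y (possibly
        # a dict keyed by node type), and per-sample weights used by filter_samples() to
        # select the samples that contribute to the loss.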
        X, y, weights = batch
        y_hat, proximity_loss = self.forward(X)

        if isinstance(y, dict) and len(y) > 1:
            y = y[self.head_node_type]

        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        loss = self.criterion.forward(y_hat, y)

        self.train_metrics.update_metrics(y_hat, y, weights=None)

        logs = None
        if self.hparams.use_proximity:
            loss = loss + proximity_loss
            logs = {"proximity_loss": proximity_loss}

        outputs = {'loss': loss}
        if logs is not None:
            outputs.update({'progress_bar': logs, "logs": logs})
        return outputs

    def validation_step(self, batch, batch_nb):
        X, y, weights = batch
        y_hat, proximity_loss = self.forward(X)

        if isinstance(y, dict) and len(y) > 1:
            y = y[self.head_node_type]

        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        val_loss = self.criterion.forward(y_hat, y)
        # if batch_nb == 0:
        #     self.print_pred_class_counts(y_hat, y, multilabel=self.dataset.multilabel)

        self.valid_metrics.update_metrics(y_hat, y, weights=None)

        if self.hparams.use_proximity:
            val_loss = val_loss + proximity_loss

        return {"val_loss": val_loss}

    def test_step(self, batch, batch_nb):
        X, y, weights = batch
        y_hat, proximity_loss = self.forward(X, save_betas=True)
        if isinstance(y, dict) and len(y) > 1:
            y = y[self.head_node_type]
        y_hat, y = filter_samples(Y_hat=y_hat, Y=y, weights=weights)
        test_loss = self.criterion(y_hat, y)

        if batch_nb == 0:
            self.print_pred_class_counts(y_hat,
                                         y,
                                         multilabel=self.dataset.multilabel)

        self.test_metrics.update_metrics(y_hat, y, weights=None)

        if self.hparams.use_proximity:
            test_loss = test_loss + proximity_loss

        return {"test_loss": test_loss}

    def train_dataloader(self):
        return self.dataset.train_dataloader(
            collate_fn=self.collate_fn,
            batch_size=self.hparams.batch_size,
            num_workers=max(1, int(0.4 * multiprocessing.cpu_count())))

    def val_dataloader(self, batch_size=None):
        return self.dataset.valid_dataloader(
            collate_fn=self.collate_fn,
            batch_size=self.hparams.batch_size,
            num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def valtrain_dataloader(self):
        return self.dataset.valtrain_dataloader(
            collate_fn=self.collate_fn,
            batch_size=self.hparams.batch_size,
            num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def test_dataloader(self, batch_size=None):
        return self.dataset.test_dataloader(
            collate_fn=self.collate_fn,
            batch_size=self.hparams.batch_size,
            num_workers=max(1, int(0.1 * multiprocessing.cpu_count())))

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'alpha_activation']
        optimizer_grouped_parameters = [
            {'params': [p for name, p in param_optimizer if not any(key in name for key in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for name, p in param_optimizer if any(key in name for key in no_decay)],
             'weight_decay': 0.0},
        ]

        # optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-06, lr=self.hparams.lr)

        optimizer = torch.optim.Adam(
            optimizer_grouped_parameters,
            lr=self.hparams.lr,  # momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay)
        scheduler = ReduceLROnPlateau(optimizer)
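        # ReduceLROnPlateau needs a metric to monitor, so return the Lightning dictionary
        # format with an explicit "monitor" key (matching configure_optimizers above).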

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}