import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

import transformers
from transformers import RobertaForMaskedLM, RobertaTokenizer

import pytorch_lightning as pl
from torchmetrics import Accuracy, F1  # assumed: older torchmetrics API (F1 was later renamed F1Score)

# Project-local modules (not shown here) are assumed importable:
# PANDataset, PANDatasetKUEP, TokenizerCollate, TokenizerCollateKUEP, DVProjectionHead_EmbActi


class LightningLongformerCLS(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.train_config = config

        # frozen RoBERTa MLM: inference only, no fine-tuning
        self.roberta = RobertaForMaskedLM.from_pretrained('roberta-base')
        _ = self.roberta.eval()
        for param in self.roberta.parameters():
            param.requires_grad = False

        # encoder body (contextual hidden states) and its static word-embedding table
        self.pred_model = self.roberta.roberta
        self.enc_model = self.pred_model.embeddings.word_embeddings

        # self.proj_head = DVProjectionHead()
        # self.proj_head = DVProjectionHead_ActiFirst()
        self.proj_head = DVProjectionHead_EmbActi()

        self.tkz = RobertaTokenizer.from_pretrained("roberta-base")
        self.collator = TokenizerCollate(self.tkz)

        self.lossfunc = nn.BCEWithLogitsLoss()

        # metrics operate on raw logits, so threshold at 0.0 (i.e. probability 0.5)
        self.acc = Accuracy(threshold=0.0)
        self.f1 = F1(threshold=0.0)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.train_config["learning_rate"])
        scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=10,
            num_training_steps=5000,
            num_cycles=10)
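        # step the cosine-with-hard-restarts schedule once per optimizer step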
        schedulers = [{
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }]
        return [optimizer], schedulers

    def train_dataloader(self):
        # self.dataset_train = PANDataset('./data_pickle_cutcombo/pan_all_cls/train_kucombo_only.pickle')
        # self.dataset_train = PANDataset('./data_pickle_cutcombo/pan_14e_cls/train_essays.pickle')
        self.dataset_train = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/train_novels_kucombo_only.pickle'
        )
        self.loader_train = DataLoader(
            self.dataset_train,
            batch_size=self.train_config["batch_size"],
            collate_fn=self.collator,
            num_workers=4,
            pin_memory=True,
            drop_last=False,
            shuffle=False)
        return self.loader_train

    def val_dataloader(self):
        # self.dataset_val = PANDataset('./data_pickle_cutcombo/pan_14e_cls/test02_essays_onecut.pickle')
        self.dataset_val = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/test02_novels_onecut.pickle')
        self.loader_val = DataLoader(
            self.dataset_val,
            batch_size=self.train_config["batch_size"],
            collate_fn=self.collator,
            num_workers=4,
            pin_memory=True,
            drop_last=False,
            shuffle=False)
        return self.loader_val

    def test_dataloader(self):
        # self.dataset_test = PANDataset('./data_pickle_cutcombo/pan_14e_cls/test02_essays_onecut.pickle')
        self.dataset_test = PANDataset(
            './data_pickle_cutcombo/pan_14n_cls/test02_novels_onecut.pickle')
        self.loader_test = DataLoader(
            self.dataset_test,
            batch_size=self.train_config["batch_size"],
            collate_fn=self.collator,
            num_workers=4,
            pin_memory=True,
            drop_last=False,
            shuffle=False)
        return self.loader_test

    @autocast()
    def forward(self, inputs, onedoc_enc=False):
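        # one_doc_embed: deduplicate identical rows in the batch, look up static
        # word embeddings, then mask one position at a time (skipping <s>/</s>) and
        # run the frozen encoder to get each masked token's contextual prediction;
        # finally restore the original batch order via the inverse indices.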
        def one_doc_embed(input_ids, input_mask, mask_n=1):
            uniq_mask = []
            uniq_input, inverse_indices = torch.unique(input_ids,
                                                       return_inverse=True,
                                                       dim=0)
            invi = inverse_indices.detach().cpu().numpy()
            for i in range(uniq_input.shape[0]):
                first_index = np.where(invi == i)[0][0]
                uniq_mask.append(input_mask[first_index, :])

            input_ids = uniq_input
            input_mask = torch.stack(uniq_mask, dim=0)

            embed = self.enc_model(input_ids)

            result_embed = []
            result_pred = []
            # mask mask_n tokens at a time, skipping the <s> and </s> positions
            masked_ids = input_ids.clone()
            for i in range(1, input_ids.shape[1] - mask_n):
                masked_ids[:, i:(i + mask_n)] = self.tkz.mask_token_id

                output = self.pred_model(input_ids=masked_ids,
                                         attention_mask=input_mask,
                                         return_dict=False)[0]
                result_embed.append(embed[:, i:(i + mask_n), :])
                result_pred.append(output[:, i:(i + mask_n), :])

                masked_ids[:, i:(i + mask_n)] = input_ids[:, i:(i + mask_n)]

            # stack along doc_len
            result_embed = torch.cat(result_embed, dim=1)
            result_pred = torch.cat(result_pred, dim=1)

            rec_embed = []
            rec_pred = []
            for i in invi:
                rec_embed.append(result_embed[i, :, :])
                rec_pred.append(result_pred[i, :, :])

            rec_embed = torch.stack(rec_embed, dim=0)
            rec_pred = torch.stack(rec_pred, dim=0)
            return rec_embed, rec_pred

        if onedoc_enc:
            # single-document path: encode one batch of documents (e.g. to
            # precompute embeddings/predictions offline)
            doc_ids, doc_mask = inputs
            doc_embed, doc_pred = one_doc_embed(input_ids=doc_ids,
                                                input_mask=doc_mask)
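            # "DV" = prediction minus embedding (per-token difference vector)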
            doc_dv = doc_pred - doc_embed
            return doc_pred, doc_embed, doc_dv
        else:
            # pair path: score a known-author vs. unknown-author document pair
            labels, kno_ids, kno_mask, unk_ids, unk_mask = inputs

            kno_embed, kno_pred = one_doc_embed(input_ids=kno_ids,
                                                input_mask=kno_mask)
            unk_embed, unk_pred = one_doc_embed(input_ids=unk_ids,
                                                input_mask=unk_mask)

            kno_dv = kno_pred - kno_embed
            unk_dv = unk_pred - unk_embed

            # logits = self.proj_head(kno_dv, kno_mask[:,1:-1], unk_dv, unk_mask[:,1:-1])
            logits = self.proj_head(kno_embed, kno_dv, kno_mask[:, 1:-1],
                                    unk_embed, unk_dv, unk_mask[:, 1:-1])

            logits = torch.squeeze(logits, dim=-1)  # keep batch dim even when batch_size == 1
            labels = labels.float()
            loss = self.lossfunc(logits, labels)

            return (loss, logits, (kno_embed, kno_pred, unk_embed, unk_pred))

    def training_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, unk_ids, unk_mask = batch

        loss, logits, outputs = self(
            (labels, kno_ids, kno_mask, unk_ids, unk_mask))

        self.log("train_loss", loss)
        self.log("logits mean", logits.mean())
        self.log("LR", self.trainer.optimizers[0].param_groups[0]['lr'])

        return loss

    def validation_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, unk_ids, unk_mask = batch

        loss, logits, outputs = self(
            (labels, kno_ids, kno_mask, unk_ids, unk_mask))

        self.acc(logits, labels.float())
        self.f1(logits, labels.float())

        return {"val_loss": loss}

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack(
            [x['val_loss'] for x in validation_step_outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log('eval accuracy', self.acc.compute())
        self.log('eval F1', self.f1.compute())
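

# The class below is a variant that consumes *precomputed* "KUEP" tensors
# (known/unknown ids, masks, embeddings, predictions) instead of running the
# frozen RoBERTa during training. It reuses the same class name, so it shadows
# the definition above if both are kept in one module.
#
# Minimal sketch (an assumption, not the repo's actual preprocessing script) of
# how such tensors could be produced with the onedoc_enc=True path above; the
# helper name precompute_kuep and the saved record layout are illustrative only.
@torch.no_grad()
def precompute_kuep(model, loader, out_path):
    records = []
    for labels, kno_ids, kno_mask, unk_ids, unk_mask in loader:
        # onedoc_enc=True returns (pred, embed, dv) for one batch of documents
        kno_pred, kno_embed, _ = model((kno_ids, kno_mask), onedoc_enc=True)
        unk_pred, unk_embed, _ = model((unk_ids, unk_mask), onedoc_enc=True)
        records.append((labels, kno_ids, kno_mask, kno_embed, kno_pred,
                        unk_ids, unk_mask, unk_embed, unk_pred))
    torch.save(records, out_path)

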
class LightningLongformerCLS(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.train_config = config

        self.roberta = RobertaForMaskedLM.from_pretrained('roberta-base')
        _ = self.roberta.eval()
        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pred_model = self.roberta.roberta
        self.enc_model = self.pred_model.embeddings.word_embeddings
        self.proj_head = DVProjectionHead_EmbActi()

        self.lossfunc = nn.BCEWithLogitsLoss()

        self.acc = Accuracy(threshold=0.0)
        self.f1 = F1(threshold=0.0)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.train_config["learning_rate"])
        return optimizer
        # scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
        #                                                                             num_warmup_steps=20,
        #                                                                             num_training_steps=10000,
        #                                                                             num_cycles=10)
        # schedulers = [
        # {
        #  'scheduler': scheduler,
        #  'interval': 'step',
        #  'frequency': 1
        # }]
        # return [optimizer], schedulers

    def train_dataloader(self):
        self.dataset_train = PANDatasetKUEP(
            "data_pickle_cutcombo/pan_13_cls/train_KUEP_combo.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_14e_cls/train_KUEP_combo.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_14n_cls/train_KUEP_combo110k.pt")
        # self.dataset_train = PANDatasetKUEP("data_pickle_cutcombo/pan_15_cls/train_KUEP_combo.pt")
        self.loader_train = DataLoader(
            self.dataset_train,
            batch_size=self.train_config["batch_size"],
            collate_fn=TokenizerCollateKUEP(),
            num_workers=1,
            pin_memory=True,
            drop_last=False,
            shuffle=True)
        return self.loader_train

    def val_dataloader(self):
        # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_14e_cls/test02_essays_KUEP.pt", test_1cut=True)
        # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_14n_cls/test02_KUEP.pt", test_1cut=True)
        # self.dataset_val = PANDatasetKUEP("data_pickle_cutcombo/pan_15_cls/test_KUEP.pt", test_1cut=True)
        self.dataset_val = PANDatasetKUEP(
            "data_pickle_cutcombo/pan_13_cls/test02_KUEP.pt",
            test_1cut=True)
        self.loader_val = DataLoader(
            self.dataset_val,
            batch_size=self.train_config["batch_size"],
            collate_fn=TokenizerCollateKUEP(),
            num_workers=1,
            pin_memory=True,
            drop_last=False,
            shuffle=False)
        return self.loader_val

    @autocast()
    def forward(self, inputs):
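        # kno_embed/kno_pred and unk_embed/unk_pred arrive precomputed from
        # PANDatasetKUEP, so no RoBERTa forward pass is needed here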
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
                unk_ids, unk_mask, unk_embed, unk_pred = inputs
        kno_dv = kno_pred - kno_embed
        unk_dv = unk_pred - unk_embed

        logits = self.proj_head(kno_embed, kno_dv, kno_mask[:, 1:-1],
                                unk_embed, unk_dv, unk_mask[:, 1:-1])

        logits = torch.squeeze(logits, dim=-1)  # keep batch dim even when batch_size == 1
        labels = labels.float()
        loss = self.lossfunc(logits, labels)

        return (loss, logits, (kno_embed, kno_pred, unk_embed, unk_pred))

    def training_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
                unk_ids, unk_mask, unk_embed, unk_pred  = batch

        loss, logits, outputs = self(batch)

        self.log("train_loss", loss)
        self.log("logits mean", logits.mean())
        self.log("LR", self.trainer.optimizers[0].param_groups[0]['lr'])

        return loss

    def validation_step(self, batch, batch_idx):
        labels, kno_ids, kno_mask, kno_embed, kno_pred, \
                unk_ids, unk_mask, unk_embed, unk_pred  = batch

        loss, logits, outputs = self(batch)

        self.acc(logits, labels.float())
        self.f1(logits, labels.float())

        return {"val_loss": loss}

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack(
            [x['val_loss'] for x in validation_step_outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log('eval accuracy', self.acc.compute())
        self.log('eval F1', self.f1.compute())
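

# Minimal usage sketch. The config values and Trainer flags are illustrative
# assumptions; only the "learning_rate" and "batch_size" keys are read above.
if __name__ == "__main__":
    train_config = {"learning_rate": 1e-4, "batch_size": 8}
    model = LightningLongformerCLS(train_config)
    # old-style (validation_epoch_end era) Trainer arguments to match the code above
    trainer = pl.Trainer(gpus=1, precision=16, max_epochs=5)
    trainer.fit(model)  # train/val dataloaders are provided by the LightningModule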