    def compute_calibration(
        self,
        sequences_list: List[str],
        batch_size: int = 1,
        pass_mode: str = "forward",
        tokens_list: Optional[List[str]] = None,
        n_bins: int = 10,
    ) -> Dict[str, Any]:
        """Compute model calibration from the input sequences

        Args:
            sequences_list: List of sequences
            batch_size: Batch size. Defaults to 1.
            pass_mode: Mode of model evaluation ('forward' or 'masked').
                Defaults to "forward".
            tokens_list: List of tokens to consider. Defaults to None
                (natural amino acids).
            n_bins: Number of bins used to compute calibration. Defaults to 10.

        Returns:
            Dict[str, Any]: dictionary of calibration metrics computed over n_bins bins
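
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            calibration = bio_trans.compute_calibration(seqs, batch_size=2, n_bins=10)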
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        _check_sequence(sequences_list, self.model_dir, 1024)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences_list, tokens_list)
        logits = self._compute_logits(inputs, batch_size, pass_mode)
        logits, labels = self._filter_logits(logits, labels, tokens)
        calibration_dict = self._compute_calibration(logits, labels, n_bins)

        return calibration_dict

    def compute_loglikelihood(
        self,
        sequences_list: List[str],
        batch_size: int = 1,
        tokens_list: Optional[List[str]] = None,
        pass_mode: str = "forward",
    ) -> np.ndarray:
        """Function that computes loglikelihoods of sequences

        Args:
            sequences_list: List of sequences
            batch_size: Batch size
            pass_mode: Mode of model evaluation ('forward' or 'masked')
            tokens_list: List of tokens to consider

        Returns:
            np.ndarray: log-likelihoods of the input sequences, as a numpy array
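
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            # numpy array with the log-likelihood of each input sequence
            logl = bio_trans.compute_loglikelihood(seqs, pass_mode="masked")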
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        _check_sequence(sequences_list, self.model_dir, 1024)
        _check_memory_logits(sequences_list, self.vocab_size, pass_mode)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences_list, tokens_list)
        logits = self._compute_logits(inputs, batch_size, pass_mode)
        loglikelihoods = self._filter_loglikelihoods(logits, labels, tokens)

        return loglikelihoods.numpy()

    def compute_accuracy(
        self,
        sequences_list: List[str],
        batch_size: int = 1,
        pass_mode: str = "forward",
        tokens_list: Optional[List[str]] = None,
    ) -> float:
        """Compute model accuracy from the input sequences

        Args:
            sequences_list: List of sequences
            batch_size: Batch size. Defaults to 1.
            pass_mode: Mode of model evaluation ('forward' or 'masked').
                Defaults to "forward".
            tokens_list: List of tokens to consider. Defaults to None
                (natural amino acids).

        Returns:
            float: accuracy of the model on the input sequences
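
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            acc = bio_trans.compute_accuracy(seqs, pass_mode="masked")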
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        _check_sequence(sequences_list, self.model_dir, 1024)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences_list, tokens_list)
        logits = self._compute_logits(inputs, batch_size, pass_mode)
        logits, labels = self._filter_logits(logits, labels, tokens)
        accuracy = self._compute_accuracy(logits, labels)

        return accuracy

    def compute_embeddings(
        self,
        sequences: Union[List[str], str],
        batch_size: int = 1,
        pool_mode: Tuple[str, ...] = ("cls", "mean"),
        tokens_list: Optional[List[str]] = None,
        silent: bool = False,
    ) -> Dict[str, np.ndarray]:
        """Function that computes embeddings of sequences.

        The model embedding has shape (n_sequences, num_tokens, embeddings_size), so
        an aggregation function specified in pool_mode is used to pool over the
        num_tokens dimension; 'mean', for instance, averages the embeddings over
        that dimension.

        Args:
            sequences: List of sequences or path of fasta file
            batch_size: Batch size
            pool_mode: Mode of pooling ('cls', 'mean', 'min', 'max'); 'full' additionally
                requires all sequences to have the same length
            tokens_list: List of tokens to consider
            silent: whether to display the progress bar

        Returns:
            Dict[str, np.ndarray]: dictionary with one embeddings array per pool_mode,
                of shape [number_of_sequences, embeddings_size] for the pooled modes
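
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            emb = bio_trans.compute_embeddings(seqs, pool_mode=("cls", "mean"))
            emb["mean"].shape  # (2, embeddings_size)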
        """
        if "full" in pool_mode and not all(
                len(s) == len(sequences[0]) for s in sequences):
            raise Exception(
                'Sequences must be of same length when pool_mode = ("full",)')

        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        if isinstance(sequences, str):
            sequences = load_fasta(sequences)

        _check_sequence(sequences, self.model_dir, 1024)
        _check_memory_embeddings(sequences, self.embeddings_size, pool_mode)

        inputs, _, tokens = self._process_sequences_and_tokens(
            sequences, tokens_list)
        # Accumulate pooled embeddings batch by batch, one tensor per pool_mode key.
        embeddings_dict = dict(
            zip(pool_mode, [torch.Tensor()] * len(pool_mode)))

        for batch_inputs in tqdm(
                self._generate_chunks(inputs, batch_size),
                total=self._get_num_batch_iter(inputs, batch_size),
                disable=silent,
        ):
            _, batch_embeddings = self._model_pass(batch_inputs)
            batch_labels = batch_inputs["input_ids"]

            batch_embeddings_dict = self._filter_and_pool_embeddings(
                batch_embeddings, batch_labels, tokens, pool_mode)

            for key in pool_mode:
                embeddings_dict[key] = torch.cat(
                    (embeddings_dict[key], batch_embeddings_dict[key]), dim=0)

        return {key: value.numpy() for key, value in embeddings_dict.items()}

    def compute_embeddings(
        self,
        sequences_list: List[str],
        batch_size: int = 1,
        pool_mode: Tuple[str, ...] = ("cls", "mean"),
        tokens_list: Optional[List[str]] = None,
    ) -> Dict[str, np.ndarray]:
        """Function that computes embeddings of sequences

        Args:
            sequences_list: List of sequences
            batch_size: Batch size
            pool_mode: Mode of pooling ('cls', 'mean', 'min', 'max')
            tokens_list: List of tokens to consider

        Returns:
            Dict[str, np.ndarray]: dictionary with one embeddings array per pool_mode,
                of shape [number_of_sequences, embeddings_size]
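
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here; with pool_mode=("full",) all sequences
        must have the same length):

            seqs = ["MKTAYIAKQRGG", "MKAAVLTLAVLF"]
            emb = bio_trans.compute_embeddings(seqs, pool_mode=("full",))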
        """
        if "full" in pool_mode and not all(
                len(s) == len(sequences_list[0]) for s in sequences_list):
            raise Exception(
                "Sequences must be of same length when pool_mode = (\"full\",)"
            )

        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        _check_sequence(sequences_list, self.model_dir, 1024)
        _check_memory_embeddings(sequences_list, self.embeddings_size,
                                 pool_mode)

        inputs, _, tokens = self._process_sequences_and_tokens(
            sequences_list, tokens_list)
        embeddings_dict = dict(
            zip(pool_mode, [torch.Tensor()] * len(pool_mode)))

        for batch_inputs in tqdm(
                self._generate_chunks(inputs, batch_size),
                total=self._get_num_batch_iter(inputs, batch_size),
        ):
            _, batch_embeddings = self._model_pass(batch_inputs)
            batch_labels = batch_inputs["input_ids"]

            batch_embeddings_dict = self._filter_and_pool_embeddings(
                batch_embeddings, batch_labels, tokens, pool_mode)

            for key in pool_mode:
                embeddings_dict[key] = torch.cat(
                    (embeddings_dict[key], batch_embeddings_dict[key]), dim=0)

        return {key: value.numpy() for key, value in embeddings_dict.items()}

    def compute_logits(
        self,
        sequences: Union[List[str], str],
        batch_size: int = 1,
        tokens_list: Optional[List[str]] = None,
        pass_mode: str = "forward",
        silent: bool = False,
    ) -> List[np.ndarray]:
        """Function that computes the logits from sequences.

        It returns a list of logits, one numpy array per sequence, restricted to
        the amino acids of interest (tokens_list).

        Args:
            sequences: List of sequences or path of fasta file
            batch_size: number of sequences to consider for the forward pass
            pass_mode: Mode of model evaluation ('forward' or 'masked')
            tokens_list: List of tokens to consider
            silent: whether to display the progress bar

        Returns:
            List[np.ndarray]: list of logits arrays, one per input sequence
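
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            logits = bio_trans.compute_logits(seqs, batch_size=2)
            len(logits)         # one numpy array per sequence
            logits[0].shape[0]  # len(seqs[0]) positions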
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        if isinstance(sequences, str):
            sequences = load_fasta(sequences)

        _check_sequence(sequences, self.model_dir, 1024)
        _check_memory_logits(sequences, self.vocab_size, pass_mode)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences, tokens_list)
        logits = self._compute_logits(inputs,
                                      batch_size,
                                      pass_mode,
                                      silent=silent)
        logits, labels = self._filter_logits(logits, labels, tokens)

        lengths = [len(sequence) for sequence in sequences]
        splitted_logits = torch.split(logits, lengths, dim=0)
        splitted_logits = [logits.numpy() for logits in splitted_logits]

        return splitted_logits

    def compute_calibration(
        self,
        sequences: Union[List[str], str],
        batch_size: int = 1,
        pass_mode: str = "forward",
        tokens_list: Optional[List[str]] = None,
        n_bins: int = 10,
        silent: bool = False,
    ) -> Dict[str, Any]:
        """Compute model calibration from the input sequences

        Args:
            sequences: List of sequences or path of fasta file
            batch_size: Batch size. Defaults to 1.
            pass_mode: Mode of model evaluation ('forward' or 'masked').
                Defaults to "forward".
            tokens_list: List of tokens to consider. Defaults to None
                (natural amino acids).
            n_bins: Number of bins used to compute calibration. Defaults to 10.
            silent: whether to display the progress bar

        Returns:
            Dict[str, Any]: dictionary of calibration metrics computed over n_bins bins
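
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and "sequences.fasta" for an existing fasta file):

            calibration = bio_trans.compute_calibration(
                "sequences.fasta", n_bins=10, silent=True
            )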
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        if isinstance(sequences, str):
            sequences = load_fasta(sequences)

        _check_sequence(sequences, self.model_dir, 1024)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences, tokens_list)
        logits = self._compute_logits(inputs,
                                      batch_size,
                                      pass_mode,
                                      silent=silent)
        logits, labels = self._filter_logits(logits, labels, tokens)
        calibration_dict = self._compute_calibration(logits, labels, n_bins)

        return calibration_dict

    def compute_accuracy(
        self,
        sequences: Union[List[str], str],
        batch_size: int = 1,
        pass_mode: str = "forward",
        tokens_list: Optional[List[str]] = None,
        silent: bool = False,
    ) -> float:
        """Compute model accuracy from the input sequences

        Args:
            sequences: List of sequences or path of fasta file
            batch_size: Batch size. Defaults to 1.
            pass_mode: Mode of model evaluation ('forward' or 'masked').
                Defaults to "forward".
            tokens_list: List of tokens to consider. Defaults to None
                (natural amino acids).
            silent: whether to display the progress bar

        Returns:
            float: accuracy of the model on the input sequences
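
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and "sequences.fasta" for an existing fasta file):

            acc = bio_trans.compute_accuracy(
                "sequences.fasta", pass_mode="masked", silent=True
            )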
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        if isinstance(sequences, str):
            sequences = load_fasta(sequences)
        _check_sequence(sequences, self.model_dir, 1024)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences, tokens_list)
        logits = self._compute_logits(inputs,
                                      batch_size,
                                      pass_mode,
                                      silent=silent)
        logits, labels = self._filter_logits(logits, labels, tokens)
        accuracy = self._compute_accuracy(logits, labels)

        return accuracy

    def train_masked(
        self,
        train_sequences: Union[List[str], str],
        lr: float = 1.0e-5,
        warmup_updates: int = 1024,
        warmup_init_lr: float = 1e-7,
        epochs: int = 10,
        batch_size: int = 2,
        acc_batch_size: int = 256,
        masking_ratio: float = 0.025,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.15,
        toks_per_batch: int = 2048,
        filter_len: int = 1024,
        accelerator: str = "ddp",
        amp_level: str = "O2",
        precision: int = 16,
        logs_save_dir: str = "logs",
        logs_name_exp: str = "finetune_masked",
        checkpoint: Optional[str] = None,
        save_last_checkpoint: bool = True,
    ):
        """Function to finetune a model on a specific dataset

        This function will finetune the chosen model on a dataset of
        sequences with PyTorch Lightning. You can modify the masking ratio of
        amino acids in the arguments for better convergence.
        Be careful with the accelerator that you use: the DDP accelerator
        launches multiple Python processes and should not be used in a notebook.

        More information on GPU/accelerator compatibility here:
            https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html
        The wisest choice would be to use DDP for multi-GPU training.

        Args:
            train_sequences: Could be a list of sequences or the path of a
                             fasta file with multiple seqRecords
            lr: learning rate for the training phase. Defaults to 1.0e-5.
            warmup_updates: number of warmup updates, i.e. the number of steps during
                            which the learning rate increases. Defaults to 1024.
            warmup_init_lr: initial learning rate for the warmup phase. Defaults to 1e-7.
            epochs: number of epochs for training. Defaults to 10.
            batch_size: number of sequences to consider in a batch. Defaults to 2.
            acc_batch_size: accumulated batch size. Defaults to 256.
            masking_ratio: ratio of tokens to be masked. Defaults to 0.025.
            masking_prob: probability that the chosen token is replaced with a mask token.
                          Defaults to 0.8.
            random_token_prob: probability that the chosen token is replaced with a random
                               token. Defaults to 0.15.
            toks_per_batch: maximum number of tokens to consider in a batch. Defaults to 2048.
                            This argument dynamically sets the number of sequences per batch;
                            acc_batch_size and batch_size are used to compute the
                            accumulate_grad_batches parameter of the trainer.
            filter_len: maximum sequence length used to filter the training set.
                        Defaults to 1024.
            accelerator: type of accelerator for multi-GPU processing ('ddp' recommended).
            amp_level: mixed-precision optimization level. Defaults to 'O2'.
            precision: numerical precision; reducing it decreases the GPU memory needed.
                       Defaults to 16 (float16).
            logs_save_dir: directory where logs are saved. Defaults to "logs".
            logs_name_exp: name of the experiment in the logs.
            checkpoint: path to a checkpoint file to restore a training session.
            save_last_checkpoint: save the last checkpoint and the 2 best models so the
                                  training session can be restored. Takes a large amount
                                  of time and memory.
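
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and "train.fasta" for an existing fasta file; at least one GPU
        is required):

            bio_trans.train_masked(
                "train.fasta",
                lr=1.0e-5,
                epochs=5,
                batch_size=2,
                acc_batch_size=256,  # accumulate_grad_batches = 256 // 2 = 128
                accelerator="ddp",
            )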
        """
        if isinstance(train_sequences, str):
            train_sequences = load_fasta(train_sequences)
        _check_sequence(train_sequences, self.model_dir, 1024)  # noqa: ignore

        fit_model = self.model.module if self.multi_gpu else self.model  # type: ignore
        alphabet = self._get_alphabet_dataloader()

        extra_toks_per_seq = int(alphabet.prepend_bos) + int(
            alphabet.append_eos)
        lightning_model = LightningModule(
            model=fit_model,
            alphabet=alphabet,
            lr=lr,
            warmup_updates=warmup_updates,
            warmup_init_lr=warmup_init_lr,
            warmup_end_lr=lr,
        )

        data_module = BioDataModule(
            train_sequences,
            alphabet,
            filter_len,
            batch_size,
            masking_ratio,
            masking_prob,
            random_token_prob,
            toks_per_batch,
            extra_toks_per_seq,
        )

        if torch.cuda.is_available():
            n_gpus = torch.cuda.device_count()
        else:
            log.warning(
                "You are trying to train the model without a GPU: training aborted.")
            return

        logger = CSVLogger(logs_save_dir, name=logs_name_exp)
        checkpoint_callback = None

        if save_last_checkpoint:
            checkpoint_callback = [
                ModelCheckpoint(
                    save_last=True,
                    save_top_k=2,
                    mode="max",
                    monitor="val_acc",
                    every_n_val_epochs=3,
                )
            ]

        trainer = Trainer(
            gpus=n_gpus,
            amp_level=amp_level,
            precision=precision,
            accumulate_grad_batches=acc_batch_size // batch_size,
            max_epochs=epochs,
            logger=logger,
            accelerator=accelerator,
            replace_sampler_ddp=False,
            resume_from_checkpoint=checkpoint,
            callbacks=checkpoint_callback,
        )

        trainer.fit(lightning_model, data_module)

        save_path = str(Path(join(logs_save_dir, logs_name_exp)).resolve())
        if accelerator == "ddp":
            rank = os.environ.get("LOCAL_RANK", None)
            rank = int(rank) if rank is not None else None  # type: ignore
            if rank == 0:
                self.save_model(save_path, lightning_model)
        else:
            self.save_model(save_path, lightning_model)

        if self.multi_gpu:
            self.model = DataParallel(lightning_model.model).to(self._device)
        else:
            self.model = lightning_model.model.to(self._device)

        log.info("Training completed.")

    def compute_probabilities(
        self,
        sequences_list: List[str],
        batch_size: int = 1,
        tokens_list: Optional[List[str]] = None,
        pass_mode: str = "forward",
        silent: bool = False,
    ) -> List[Dict[int, Dict[str, float]]]:
        """Function that computes the probabilities over amino-acids from sequences.

        It takes as inputs a list of sequences and returns a list of dictionaries.
        Each dictionary contains the probabilities over the natural amino-acids for each
        position in the sequence. The keys represent the positions (indexed
        starting with 0) and the values are dictionaries of probabilities over
        the natural amino-acids for this position.

        In these dictionaries, the keys are the amino acids and the values are
        the corresponding probabilities.

        Args:
            sequences_list: List of sequences
            batch_size: number of sequences to consider for the forward pass
            pass_mode: Mode of model evaluation ('forward' or 'masked')
            tokens_list: List of tokens to consider
            silent : display or not progress bar
        Returns:
            List[Dict[int, Dict[str, float]]]: dictionaries of probabilities per seq
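
        Example (illustrative only; bio_trans stands for an instance of this
        wrapper and is not defined here):

            seqs = ["MKTAYIAKQR", "MKAAVLTLAVLFLTGSQA"]
            probs = bio_trans.compute_probabilities(seqs, batch_size=2)
            probs[0][0]["M"]  # probability of 'M' at position 0 of the first sequence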
        """
        if tokens_list is None:
            tokens_list = NATURAL_AAS_LIST

        _check_sequence(sequences_list, self.model_dir, 1024)
        _check_memory_logits(sequences_list, self.vocab_size, pass_mode)

        inputs, labels, tokens = self._process_sequences_and_tokens(
            sequences_list, tokens_list)
        logits = self._compute_logits(inputs,
                                      batch_size,
                                      pass_mode,
                                      silent=silent)
        logits, _ = self._filter_logits(logits, labels, tokens)

        # Split the concatenated logits back into one tensor per input sequence,
        # then turn each position's logits into probabilities.
        lengths = [len(sequence) for sequence in sequences_list]
        splitted_logits = torch.split(logits, lengths, dim=0)

        softmax = torch.nn.Softmax(dim=-1)
        splitted_probabilities = [
            softmax(logits) for logits in splitted_logits
        ]

        def _get_probabilities_dict(probs: torch.Tensor) -> Dict[str, float]:
            return {
                aa: float(probs[i].cpu().numpy())
                for i, aa in enumerate(NATURAL_AAS_LIST)
            }

        probabilities = [{
            key: _get_probabilities_dict(value)
            for key, value in dict(enumerate(split)).items()
        } for split in splitted_probabilities]

        return probabilities