Code Example #1
def convert_examples_to_features(
        examples: List[InputExample],
        tokenizer: PreTrainedTokenizer,
        slot_label_encoder: LabelEncoder,
        intent_label_encoder: LabelEncoder,
        max_length: Optional[int] = None
):
    if max_length is None:
        max_length = tokenizer.max_len

    # Generate slot / intent label ids.
    truncation_size = max_length // 2
    features = []
    for example in examples:
        tmp_label_slot_origin = slot_label_encoder.batch_encode(example.label_slot_raw).numpy().tolist()
        if example.label_intent_raw:
            tmp_label_intent_origin = intent_label_encoder.batch_encode(example.label_intent_raw).numpy().tolist()
        else:
            tmp_label_intent_origin = []
        # Manually truncate query/context, each to at most `max_length // 2` characters.
        tmp_text_query = example.text_query[:truncation_size]  # truncate the query from the end (keep the head)
        tmp_label_slot_origin = tmp_label_slot_origin[:truncation_size]
        tmp_text_context = example.text_context[-truncation_size:]  # truncate the context from the front (keep the tail)
        # TODO: optionally swap the query/context order
        tokenizer_outputs = tokenizer.encode_plus(list(tmp_text_query), list(tmp_text_context), padding='max_length',
                                                  max_length=max_length)
        # Build the slot_mask (1 over the query positions, 0 elsewhere).
        tmp_slot_mask = np.asarray([0] * max_length)
        tmp_slot_mask[1:len(tmp_text_query) + 1] = 1
        # tmp_slot_mask[len(tmp_text_context)+2: len(tmp_text_context)+len(tmp_text_query)+2] = 1
        tmp_slot_mask = list(tmp_slot_mask)

        # Build tmp_label_slot (slot label ids aligned with the query positions).
        tmp_label_slot = [0] * max_length
        tmp_label_slot[1:len(tmp_text_query) + 1] = tmp_label_slot_origin

        # Build tmp_label_intent as a multi-hot vector over the intent vocabulary.
        tmp_label_intent = np.asarray([0] * intent_label_encoder.vocab_size)
        tmp_label_intent[tmp_label_intent_origin] = 1
        tmp_label_intent = list(tmp_label_intent)

        feature = InputFeatures(**tokenizer_outputs, slot_mask=tmp_slot_mask, label_slot=tmp_label_slot,
                                label_intent=tmp_label_intent)
        features.append(feature)

    for i, example in enumerate(examples[:5]):
        logger.info("*** Example ***")
        logger.info("guid: %s" % example.guid)
        logger.info("features: %s" % features[i])
    return features
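
For orientation, here is a minimal driver sketch for the function above. The `InputExample` dataclass below is a hypothetical stand-in with exactly the attributes the loop reads; `InputFeatures`, `logger` and `np` are assumed to come from the same module as the snippet, and the tokenizer / label encoders are ordinary `transformers` / `torchnlp` objects with illustrative label sets.

from dataclasses import dataclass, field
from typing import List

from torchnlp.encoders import LabelEncoder
from transformers import BertTokenizer

@dataclass
class InputExample:  # hypothetical container matching the attributes used above
    guid: str
    text_query: str
    text_context: str
    label_slot_raw: List[str]
    label_intent_raw: List[str] = field(default_factory=list)

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
slot_label_encoder = LabelEncoder(["O", "B-city", "I-city"], reserved_labels=[])
intent_label_encoder = LabelEncoder(["weather", "music"], reserved_labels=[])

examples = [InputExample(
    guid="dev-0",
    text_query="北京今天天气",
    text_context="今天出门要带伞吗",
    label_slot_raw=["B-city", "I-city", "O", "O", "O", "O"],
    label_intent_raw=["weather"],
)]
features = convert_examples_to_features(
    examples, tokenizer, slot_label_encoder, intent_label_encoder, max_length=32)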
Code Example #2
def prepare_sample(
    sample: dict, text_encoder: WhitespaceEncoder, label_encoder: LabelEncoder,
    max_length: int
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    """
    Function that receives a sample from the Dataset iterator and prepares the
    input to feed the transformer model.
    :param sample: dictionary containing the inputs to build the batch 
        (e.g: [{'source': 'This flight was amazing!', 'target': 'pos'}, 
               {'source': 'I hate Iberia', 'target': 'neg'}])
    :param text_encoder: Torch NLP text encoder for tokenization and vectorization.
    :param label_encoder: Torch NLP label encoder for vectorization of labels.
    :param max_length: Max length of the input sequences.
         If a sequence exceeds that value it is truncated.
    """
    sample = collate_tensors(sample)
    input_seqs, input_lengths = text_encoder.batch_encode(sample['source'])
    target_seqs = label_encoder.batch_encode(sample['target'])
    # Truncate Inputs
    if input_seqs.size(1) > max_length:
        input_seqs = input_seqs[:, :max_length]
    input_mask = lengths_to_mask(input_lengths).unsqueeze(1)
    return input_seqs, input_mask, target_seqs
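
A quick usage sketch for prepare_sample, assuming only torchnlp (collate_tensors and lengths_to_mask are expected to be imported from torchnlp.utils in the same module). The sample format follows the docstring above, and the encoders are fit on the same toy corpus.

from torchnlp.encoders import LabelEncoder
from torchnlp.encoders.text import WhitespaceEncoder

corpus = [
    {"source": "This flight was amazing!", "target": "pos"},
    {"source": "I hate Iberia", "target": "neg"},
]
text_encoder = WhitespaceEncoder([s["source"] for s in corpus])
label_encoder = LabelEncoder([s["target"] for s in corpus])

input_seqs, input_mask, target_seqs = prepare_sample(
    corpus, text_encoder, label_encoder, max_length=128)
print(input_seqs.shape, input_mask.shape, target_seqs)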
Code Example #3
class Tagger(CaptionModelBase):
    """
    Tagger base class.

    :param hparams: HyperOptArgumentParser containing the hyperparameters.
    """
    def __init__(
        self,
        hparams: HyperOptArgumentParser,
    ) -> None:
        super().__init__(hparams)

    def _build_model(self):
        self.label_encoder = LabelEncoder(self.hparams.tag_set.split(","),
                                          reserved_labels=[])

    def _build_loss(self):
        """ Initializes the loss function/s. """
        weights = (np.array([
            float(x) for x in self.hparams.class_weights.split(",")
        ]) if self.hparams.class_weights != "ignore" else np.array([]))

        if self.hparams.loss == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(
                reduction="sum",
                ignore_index=self.label_encoder.vocab_size,
                weight=torch.tensor(weights, dtype=torch.float32)
                if weights.any() else None,
            )
        else:
            raise Exception(f"{self.hparams.loss} is not a valid loss option.")

    def _retrieve_dataset(self, data_hparams, train=True, val=True, test=True):
        """ Retrieves task specific dataset """
        return sequence_tagging_dataset(data_hparams, train, val, test)

    @property
    def default_slot_index(self):
        """ Index of the default slot to be ignored. (e.g. 'O' in 'B-I-O' tags) """
        return 0

    def predict(self, sample: dict) -> list:
        """ Function that runs a model prediction,
        :param sample: a dictionary that must contain the the 'source' sequence.

        Return: list with predictions
        """
        if self.training:
            self.eval()

        return_dict = False
        if isinstance(sample, dict):
            sample = [sample]
            return_dict = True

        with torch.no_grad():
            model_input, _ = self.prepare_sample(sample, prepare_target=False)
            model_out = self.forward(**model_input)
            tag_logits = model_out["tags"]
            _, pred_labels = tag_logits.topk(1, dim=-1)

            for i in range(pred_labels.size(0)):
                sample_tags = pred_labels[i, :, :].view(-1)
                tags = [
                    self.label_encoder.index_to_token[sample_tags[j]]
                    for j in range(model_input["word_lengths"][i])
                ]
                sample[i]["predicted_tags"] = " ".join(tags)
                sample[i]["tagged_sequence"] = " ".join([
                    word + "/" + tag
                    for word, tag in zip(sample[i]["text"].split(), tags)
                ])

                sample[i][
                    "encoded_ground_truth_tags"] = self.label_encoder.batch_encode(
                        [tag for tag in sample[i]["tags"].split()])

                if self.hparams.ignore_last_tag:
                    if (sample[i]["encoded_ground_truth_tags"][
                            model_input["word_lengths"][i] - 1] == 1):
                        sample[i]["encoded_ground_truth_tags"][
                            model_input["word_lengths"][i] -
                            1] = self.label_encoder.vocab_size

        if return_dict:
            return sample[0]

        return sample

    def _compute_loss(self, model_out: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param model_out: model specific output with predicted tag logits
            a tensor [batch_size x seq_length x num_tags]
        :param targets: Target tags [batch_size x seq_length]
        """
        logits = model_out["tags"].view(-1, model_out["tags"].size(-1))
        labels = targets["tags"].view(-1)
        return self.loss(logits, labels)

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target values.
        """
        sample = collate_tensors(sample)
        inputs = self.encoder.prepare_sample(sample["text"], trackpos=True)
        if not prepare_target:
            return inputs, {}

        tags, _ = stack_and_pad_tensors(
            [
                self.label_encoder.batch_encode(tags.split())
                for tags in sample["tags"]
            ],
            padding_index=self.label_encoder.vocab_size,
        )

        if self.hparams.ignore_first_title:
            first_tokens = tags[:, 0].clone()
            tags[:, 0] = first_tokens.masked_fill_(
                first_tokens == self.label_encoder.token_to_index["T"],
                self.label_encoder.vocab_size,
            )

        # TODO is this still needed ?
        if self.hparams.ignore_last_tag:
            lengths = [len(tags.split()) for tags in sample["tags"]]
            lengths = np.asarray(lengths)
            k = 0
            for length in lengths:
                if tags[k][length - 1] == 1:
                    tags[k][length - 1] = self.label_encoder.vocab_size
                k += 1

        targets = {"tags": tags}
        return inputs, targets

    def _compute_metrics(self, outputs: list) -> dict:
        """ 
        Private function that computes metrics of interest based on model predictions 
        and respective targets.
        """
        predictions = [
            batch_out["val_prediction"]["tags"] for batch_out in outputs
        ]
        targets = [batch_out["val_target"]["tags"] for batch_out in outputs]

        predicted_tags, ground_truth = [], []
        for i in range(len(predictions)):
            # Get logits and reshape predictions
            batch_predictions = predictions[i]
            logits = batch_predictions.view(-1,
                                            batch_predictions.size(-1)).cpu()
            _, pred_labels = logits.topk(1, dim=-1)

            # Reshape targets
            batch_targets = targets[i].view(-1).cpu()

            assert batch_targets.size() == pred_labels.view(-1).size()
            ground_truth.append(batch_targets)
            predicted_tags.append(pred_labels.view(-1))

        return classification_report(
            torch.cat(predicted_tags).numpy(),
            torch.cat(ground_truth).numpy(),
            padding=self.label_encoder.vocab_size,
            labels=self.label_encoder.token_to_index,
            ignore=self.default_slot_index,
        )

    @classmethod
    def add_model_specific_args(
            cls, parser: HyperOptArgumentParser) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: HyperOptArgumentParser obj

        Returns:
            - updated parser
        """
        parser = super(Tagger, Tagger).add_model_specific_args(parser)
        parser.add_argument(
            "--tag_set",
            type=str,
            default="L,U,T",
            help="Task tags we want to use.\
                 Note that the 'default' label should appear first",
        )
        # Loss
        parser.add_argument(
            "--loss",
            default="cross_entropy",
            type=str,
            help="Loss function to be used.",
            choices=["cross_entropy"],
        )
        parser.add_argument(
            "--class_weights",
            default="ignore",
            type=str,
            help=
            'Weights for each of the classes we want to tag (e.g: "1.0,7.0,8.0").',
        )
        ## Data args:
        parser.add_argument(
            "--data_type",
            default="csv",
            type=str,
            help="The type of the file containing the training/dev/test data.",
            choices=["csv"],
        )
        parser.add_argument(
            "--train_path",
            default="data/dummy_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_path",
            default="data/dummy_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_path",
            default="data/dummy_test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=0,
            type=int,
            help=("How many subprocesses to use for data loading. 0 means that"
                  "the data will be loaded in the main process."),
        )
        # Metric args:
        parser.add_argument(
            "--ignore_first_title",
            default=False,
            help="When used, this flag ignores T tags in the first position.",
            action="store_true",
        )
        parser.add_argument(
            "--ignore_last_tag",
            default=False,
            help="When used, this flag ignores S tags in the last position.",
            action="store_true",
        )
        return parser
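
The central trick in Tagger is reusing label_encoder.vocab_size both as the padding index for the stacked tag tensors and as the ignore_index of the loss, so padded positions never contribute to the gradient. A minimal, self-contained check of that pattern (the tag set and shapes are illustrative, assuming torchnlp's stack_and_pad_tensors helper):

import torch
from torch import nn
from torchnlp.encoders import LabelEncoder
from torchnlp.encoders.text import stack_and_pad_tensors

label_encoder = LabelEncoder(["L", "U", "T"], reserved_labels=[])

# Two tag sequences of different lengths, padded with vocab_size
# so the loss ignores the padded positions.
tags, _ = stack_and_pad_tensors(
    [label_encoder.batch_encode("L L U".split()),
     label_encoder.batch_encode("L T".split())],
    padding_index=label_encoder.vocab_size,
)
loss_fn = nn.CrossEntropyLoss(reduction="sum", ignore_index=label_encoder.vocab_size)

# Random logits over vocab_size classes for each flattened position.
logits = torch.randn(tags.numel(), label_encoder.vocab_size)
print(loss_fn(logits, tags.view(-1)))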
Code Example #4
File: main.py  Project: zxin1023/PyTorch-NLP
    with open(fn, 'rb') as f:
        model, criterion, optimizer = torch.load(f)


from torchnlp import datasets
from torchnlp.encoders import LabelEncoder
from torchnlp.samplers import BPTTBatchSampler

print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True,
                                                dev=True,
                                                test=True)

encoder = LabelEncoder(train + val + test)

train_data = encoder.batch_encode(train)
val_data = encoder.batch_encode(val)
test_data = encoder.batch_encode(test)

eval_batch_size = 10
test_batch_size = 1

train_source_sampler, val_source_sampler, test_source_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'source')
    for d in (train, val, test)
])

train_target_sampler, val_target_sampler, test_target_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'target')
    for d in (train, val, test)
])
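
According to the torchnlp documentation, BPTTBatchSampler yields one list of slices into the flattened corpus per batch; under that assumption, one BPTT batch from the samplers above can be materialized like this (a sketch, not part of the original script):

import torch

source_slices = next(iter(train_source_sampler))
target_slices = next(iter(train_target_sampler))

# Stack the per-row slices into a [bptt x batch_size] source tensor and the
# flattened next-token targets a language model would train against.
source_batch = torch.stack([train_data[s] for s in source_slices]).t_().contiguous()
target_batch = torch.stack([train_data[s] for s in target_slices]).t_().contiguous().view(-1)
print(source_batch.shape, target_batch.shape)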
Code Example #5
class BERTClassifier(pl.LightningModule):
    """
    Sample model to show how to use BERT to classify sentences.

    :param hparams: ArgumentParser containing the hyperparameters.
    """
    def __init__(self, hparams) -> None:
        super(BERTClassifier, self).__init__()
        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = AutoModel.from_pretrained(self.hparams.encoder_model,
                                              output_hidden_states=True,
                                              num_labels=11)

        # set the number of features our encoder model will return...
        if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2":
            self.encoder_features = 128
        else:
            self.encoder_features = 768

        # Tokenizer
        self.tokenizer = BERTTextEncoder("bert-base-uncased")

        # Label Encoder
        self.label_encoder = LabelEncoder(self.hparams.label_set.split(","),
                                          reserved_labels=[])
        self.label_encoder.unknown_index = None

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features, self.encoder_features * 2),
            nn.Tanh(),
            nn.Linear(self.encoder_features * 2, self.encoder_features),
            nn.Tanh(),
            nn.Linear(self.encoder_features, self.label_encoder.vocab_size),
        )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        self._loss = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            log.info(f"\n-- Encoder model fine-tuning")
            for param in self.bert.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.bert.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample],
                                                 prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            predicted_labels = [
                self.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, :lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)
        # Run BERT model.
        word_embeddings = self.bert(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(0.0, tokens, word_embeddings,
                                    self.tokenizer.padding_index)
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(
            word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x num_labels] holding the model predictions.
        :param targets: dictionary with a 'labels' key holding label values [batch_size]

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)
        tokens, lengths = self.tokenizer.batch_encode(sample["text"])

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args,
                      **kwargs) -> dict:
        """
        Runs one training step. This usually consists in the forward function followed
            by the loss function.

        :param batch: The output of your dataloader.
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        tqdm_dict = {"train_loss": loss_val}
        output = OrderedDict({
            "loss": loss_val,
            "progress_bar": tqdm_dict,
            "log": tqdm_dict
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, batch: tuple, batch_nb: int, *args,
                        **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.

        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        output = OrderedDict({
            "val_loss": loss_val,
            "val_acc": val_acc,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_end(self, outputs: list) -> dict:
        """ Function that takes as input a list of dictionaries returned by the validation_step
        function and measures the model performance across the entire validation set.

        Returns:
            - Dictionary with metrics to be added to the lightning logger.
        """
        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss = output["val_loss"]

            # reduce manually when using dp
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_loss = torch.mean(val_loss)
            val_loss_mean += val_loss

            # reduce manually when using dp
            val_acc = output["val_acc"]
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_acc = torch.mean(val_acc)

            val_acc_mean += val_acc

        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        tqdm_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean}
        result = {
            "progress_bar": tqdm_dict,
            "log": tqdm_dict,
            "val_loss": val_loss_mean,
        }
        return result

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {
                "params": self.classification_head.parameters()
            },
            {
                "params": self.bert.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    def __retrieve_dataset(self, train=True, val=True, test=True):
        """ Retrieves task specific dataset """
        return sentiment_analysis_dataset(self.hparams, train, val, test)

    @pl.data_loader
    def train_dataloader(self) -> DataLoader:
        """ Function that loads the train set. """
        self._train_dataset = self.__retrieve_dataset(val=False, test=False)[0]
        return DataLoader(
            dataset=self._train_dataset,
            sampler=RandomSampler(self._train_dataset),
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @pl.data_loader
    def val_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._dev_dataset = self.__retrieve_dataset(train=False, test=False)[0]
        return DataLoader(
            dataset=self._dev_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @pl.data_loader
    def test_dataloader(self) -> DataLoader:
        """ Function that loads the validation set. """
        self._test_dataset = self.__retrieve_dataset(train=False, val=False)[0]
        return DataLoader(
            dataset=self._test_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=self.prepare_sample,
            num_workers=self.hparams.loader_workers,
        )

    @classmethod
    def add_model_specific_args(
            cls, parser: HyperOptArgumentParser) -> HyperOptArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters.
        :param parser: HyperOptArgumentParser obj

        Returns:
            - updated parser
        """
        parser.add_argument(
            "--encoder_model",
            default="bert-base-uncased",
            type=str,
            help="Encoder model to be used.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.opt_list(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
            tunable=True,
            options=[0, 1, 2, 3, 4, 5],
        )
        # Data Args:
        parser.add_argument(
            "--label_set",
            default="pos,neg",
            type=str,
            help="Classification labels set.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/imdb_reviews_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
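
A hedged sketch of instantiating the classifier through its own argument parser and running a single prediction. --batch_size is assumed to be added by the base/trainer parser elsewhere, so it is registered here explicitly, and the snippet's module-level imports (pl, AutoModel, BERTTextEncoder, LabelEncoder, etc.) are assumed available.

from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy="random_search", add_help=False)
parser.add_argument("--batch_size", default=8, type=int)  # assumed to come from the base parser
parser = BERTClassifier.add_model_specific_args(parser)
hparams = parser.parse_args(["--encoder_model", "google/bert_uncased_L-2_H-128_A-2",
                             "--nr_frozen_epochs", "0"])

model = BERTClassifier(hparams)
print(model.predict({"text": "This movie was surprisingly good."}))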
Code Example #6
class UniProtData():
    def __init__(self, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    def from_file(self, fasta_file, reduce_to_binary=False):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq(reduce_to_binary):
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            if reduce_to_binary:
                if cur_seq_label == "0":
                    sequence_labels.append(cur_seq_label)
                else:
                    sequence_labels.append("1")
            else:
                sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq(reduce_to_binary)
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line.split("|")[-1]
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(" ".join(line.strip()))

        _flush_current_seq(reduce_to_binary)

        # Sanity check: make sure we have the same number of sequences and labels.
        assert len(sequence_strs) == len(sequence_labels)

        # Create label encoder from unique strings in the label list
        self.label_encoder = LabelEncoder(np.unique(sequence_labels),
                                          reserved_labels=[])

        self.data = []
        for i in range(len(sequence_strs)):
            self.data.append({
                "seq": str(sequence_strs[i]),
                "label": str(sequence_labels[i])
            })

        return self.data

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.

        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)

        # Tokenize the input; returns a dict with 3 entries:
        #   input_ids: tokenized matrix
        #   token_type_ids: matrix of 0/1 indicating whether the element belongs to seq0 or seq1
        #   attention_mask: matrix of 0/1 indicating whether a token is real (1) or padding (0)
        # Convert to PT tensor
        inputs = self.tokenizer.batch_encode_plus(
            sample["seq"],
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.sequence_length,
            return_tensors="pt")

        if prepare_target is False:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            print(sample["label"])
            raise Exception("Label encoder found an unknown label.")

    def get_dataloader(self, file_path, batch_size, num_worker=4):
        data = self.from_file(file_path)

        data_loader = DataLoader(data,
                                 batch_size=batch_size,
                                 sampler=RandomSampler(data),
                                 collate_fn=self.prepare_sample,
                                 num_workers=num_worker)

        return data_loader
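
A small sketch of wiring UniProtData to a Hugging Face tokenizer and pulling one batch. The FASTA path below is a placeholder, and header lines are assumed to end in a |label suffix as the parser above expects; DataLoader, RandomSampler, collate_tensors, LabelEncoder and np are assumed imported in the snippet's module.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
uniprot = UniProtData(tokenizer, sequence_length=512)

loader = uniprot.get_dataloader("data/uniprot_sample.fasta", batch_size=4, num_worker=0)
inputs, targets = next(iter(loader))
print(inputs["input_ids"].shape, targets["labels"])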