Example #1
    def __init__(self, kernel_count: int = None, kernel_size=None, positional_channels: int = None, sequence_type: str = None, device=None,
                 number_of_threads: int = None, random_seed: int = None, learning_rate: float = None, iteration_count: int = None,
                 l1_weight_decay: float = None, l2_weight_decay: float = None, batch_size: int = None, training_percentage: float = None,
                 evaluate_at: int = None, background_probabilities=None, result_path: Path = None):

        super().__init__()
        self.kernel_count = kernel_count
        self.kernel_size = kernel_size
        self.positional_channels = positional_channels
        self.number_of_threads = number_of_threads
        self.random_seed = random_seed
        self.device = device
        self.l1_weight_decay = l1_weight_decay
        self.l2_weight_decay = l2_weight_decay
        self.learning_rate = learning_rate
        self.iteration_count = iteration_count
        self.batch_size = batch_size
        self.evaluate_at = evaluate_at
        self.training_percentage = training_percentage
        self.sequence_type = SequenceType[sequence_type.upper()]
        alphabet_size = len(EnvironmentSettings.get_sequence_alphabet(self.sequence_type))
        self.background_probabilities = background_probabilities if background_probabilities is not None \
            else np.full(alphabet_size, 1. / alphabet_size)
        self.CNN = None
        self.label_name = None
        self.class_mapping = None
        self.result_path = result_path
        self.chain_names = None
        self.feature_names = None
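The default background above is just a uniform distribution over the alphabet. A minimal standalone sketch of that computation, with a hard-coded amino-acid alphabet standing in for EnvironmentSettings.get_sequence_alphabet:

import numpy as np

# hard-coded 20-letter amino-acid alphabet; the real value comes from
# EnvironmentSettings.get_sequence_alphabet(sequence_type)
alphabet = list("ACDEFGHIKLMNPQRSTVWY")

# uniform background: each letter gets probability 1 / alphabet size
background_probabilities = np.full(len(alphabet), 1.0 / len(alphabet))

assert np.isclose(background_probabilities.sum(), 1.0)
print(background_probabilities[:3])  # [0.05 0.05 0.05]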
Example #2
    def drop_illegal_character_sequences(
            dataframe: pd.DataFrame,
            import_illegal_characters: bool) -> pd.DataFrame:
        if not import_illegal_characters:
            sequence_type = EnvironmentSettings.get_sequence_type()
            sequence_name = sequence_type.name.lower().replace("_", " ")

            legal_alphabet = EnvironmentSettings.get_sequence_alphabet(
                sequence_type)
            if sequence_type == SequenceType.AMINO_ACID:
                legal_alphabet.append(Constants.STOP_CODON)

            is_illegal_seq = [
                ImportHelper.is_illegal_sequence(sequence, legal_alphabet)
                for sequence in dataframe[sequence_type.value]
            ]
            n_illegal = sum(is_illegal_seq)

            if n_illegal > 0:
                dataframe.drop(dataframe.loc[is_illegal_seq].index,
                               inplace=True)
                warnings.warn(
                    f"{ImportHelper.__name__}: {n_illegal} sequences were removed from the dataset because their {sequence_name} sequence contained illegal characters. "
                )
        return dataframe
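The filtering logic can be reproduced without the immuneML helpers. A self-contained sketch on a toy DataFrame, where the hard-coded alphabet and the local is_illegal_sequence function are stand-ins for the EnvironmentSettings and ImportHelper calls above:

import warnings
import pandas as pd

# toy alphabet standing in for EnvironmentSettings.get_sequence_alphabet();
# '*' plays the role of Constants.STOP_CODON
legal_alphabet = list("ACDEFGHIKLMNPQRSTVWY") + ["*"]

def is_illegal_sequence(sequence, alphabet) -> bool:
    # a sequence is illegal if any character falls outside the legal alphabet
    return any(char not in alphabet for char in sequence)

dataframe = pd.DataFrame({"amino_acid": ["CASSL", "CAS?L", "CSAR"]})
is_illegal_seq = [is_illegal_sequence(seq, legal_alphabet) for seq in dataframe["amino_acid"]]
n_illegal = sum(is_illegal_seq)

if n_illegal > 0:
    dataframe.drop(dataframe.loc[is_illegal_seq].index, inplace=True)
    warnings.warn(f"{n_illegal} sequences removed due to illegal characters.")

print(dataframe)  # only 'CASSL' and 'CSAR' remain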
Example #3
    def __init__(self,
                 use_positional_info: bool,
                 distance_to_seq_middle: int,
                 flatten: bool,
                 name: str = None,
                 sequence_type: SequenceType = None):
        self.use_positional_info = use_positional_info
        self.distance_to_seq_middle = distance_to_seq_middle
        self.flatten = flatten
        self.sequence_type = sequence_type
        self.alphabet = EnvironmentSettings.get_sequence_alphabet(
            self.sequence_type)

        if distance_to_seq_middle:
            self.pos_increasing = [
                1 / self.distance_to_seq_middle * i
                for i in range(self.distance_to_seq_middle)
            ]
            self.pos_decreasing = self.pos_increasing[::-1]
        else:
            # set both ramps so attribute access stays safe when positional info is unused
            self.pos_increasing = None
            self.pos_decreasing = None

        self.name = name

        if self.sequence_type == SequenceType.NUCLEOTIDE and self.distance_to_seq_middle is not None:  # todo check this / explain in docs
            self.distance_to_seq_middle = self.distance_to_seq_middle * 3

        self.onehot_dimensions = (self.alphabet + ["start", "mid", "end"]) if self.use_positional_info \
            else self.alphabet  # todo test this
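The two positional ramps built in this constructor are simple linear interpolations between 0 and 1. A standalone sketch, assuming distance_to_seq_middle = 6 (nothing here depends on immuneML itself):

distance_to_seq_middle = 6

pos_increasing = [1 / distance_to_seq_middle * i for i in range(distance_to_seq_middle)]
pos_decreasing = pos_increasing[::-1]

print(pos_increasing)  # [0.0, 0.167, 0.333, 0.5, 0.667, 0.833] (rounded)
print(pos_decreasing)  # the same ramp mirrored, ending at 0.0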
Example #4
    def __init__(self, kernel_count: int, kernel_size, positional_channels: int, sequence_type: SequenceType, background_probabilities, chain_names):
        super(PyTorchReceptorCNN, self).__init__()
        self.background_probabilities = background_probabilities
        self.threshold = 0.1
        self.pseudocount = 0.05
        self.in_channels = len(EnvironmentSettings.get_sequence_alphabet(sequence_type)) + positional_channels
        self.positional_channels = positional_channels
        self.max_information_gain = self.get_max_information_gain()
        self.chain_names = chain_names

        self.conv_chain_1 = [f"chain_1_kernel_{size}" for size in kernel_size]
        self.conv_chain_2 = [f"chain_2_kernel_{size}" for size in kernel_size]

        for size in kernel_size:
            # one convolution per chain and kernel size, weights initialized N(0, 1/n_weights)
            for chain in ("chain_1", "chain_2"):
                kernel = nn.Conv1d(in_channels=self.in_channels, out_channels=kernel_count, kernel_size=size, bias=True)
                kernel.weight.data.normal_(0.0, np.sqrt(1 / np.prod(kernel.weight.shape)))
                setattr(self, f"{chain}_kernel_{size}", kernel)

        self.fully_connected = nn.Linear(in_features=kernel_count * len(kernel_size) * 2, out_features=1, bias=True)
        self.fully_connected.weight.data.normal_(0.0, np.sqrt(1 / np.prod(self.fully_connected.weight.shape)))
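The initialization used throughout this constructor draws every weight from a zero-mean Gaussian with std = sqrt(1 / fan), where fan is the total number of entries in the weight tensor. A minimal sketch of that step on a single torch Conv1d (the channel sizes here are arbitrary assumptions):

import numpy as np
import torch.nn as nn

# one convolution with arbitrary sizes; 23 in-channels mimics a 20-letter
# alphabet plus 3 positional channels
conv = nn.Conv1d(in_channels=23, out_channels=8, kernel_size=4, bias=True)

# std = sqrt(1 / number_of_weight_entries), as in the loop above
std = np.sqrt(1 / np.prod(conv.weight.shape))
conv.weight.data.normal_(0.0, std)

print(float(conv.weight.std()), "~", float(std))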
Example #5
    def make_random_dataset(self, path):
        alphabet = EnvironmentSettings.get_sequence_alphabet()
        sequences = [["".join([rn.choice(alphabet) for i in range(20)]) for i in range(100)] for i in range(40)]

        repertoires, metadata = RepertoireBuilder.build(sequences, path, subject_ids=[i % 2 for i in range(len(sequences))])
        dataset = RepertoireDataset(repertoires=repertoires, metadata_file=metadata)
        PickleExporter.export(dataset, path)
Example #6
    def create_model(self, dataset: RepertoireDataset, k: int, vector_size: int, batch_size: int, model_path: Path):
        model = Word2Vec(size=vector_size, min_count=1, window=5)  # empty model; `size` is the gensim 3.x name (renamed to `vector_size` in gensim 4+)
        all_kmers = KmerHelper.create_all_kmers(k=k, alphabet=EnvironmentSettings.get_sequence_alphabet())
        all_kmers = [[kmer] for kmer in all_kmers]
        model.build_vocab(all_kmers)

        for repertoire in dataset.get_data(batch_size=batch_size):
            sentences = KmerHelper.create_sentences_from_repertoire(repertoire=repertoire, k=k)
            model.train(sentences=sentences, total_words=len(all_kmers), epochs=15)

        model.save(str(model_path))

        return model
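KmerHelper.create_all_kmers enumerates every k-mer over the alphabet. A stand-in sketch with itertools (the helper name is immuneML's; this body is an assumption about what it does):

from itertools import product

def create_all_kmers(k, alphabet):
    # every k-length combination of alphabet letters, joined into strings
    return ["".join(p) for p in product(alphabet, repeat=k)]

kmers = create_all_kmers(k=2, alphabet=list("ACGT"))
print(len(kmers), kmers[:5])  # 16 ['AA', 'AC', 'AG', 'AT', 'CA']

# Word2Vec's build_vocab expects sentences (lists of tokens), hence the wrapping
all_kmers = [[kmer] for kmer in kmers]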
Example #7
    def create_model(self, dataset: RepertoireDataset, k: int,
                     vector_size: int, batch_size: int, model_path: Path):

        model = Word2Vec(size=vector_size, min_count=1,
                         window=5)  # creates an empty model
        all_kmers = KmerHelper.create_all_kmers(
            k=k, alphabet=EnvironmentSettings.get_sequence_alphabet())
        all_kmers = [[kmer] for kmer in all_kmers]
        model.build_vocab(all_kmers)

        for kmer in all_kmers:
            sentences = KmerHelper.create_kmers_within_HD(
                kmer=kmer[0],
                alphabet=EnvironmentSettings.get_sequence_alphabet(),
                distance=1)
            model.train(sentences=sentences,
                        total_words=len(all_kmers),
                        epochs=model.epochs)

        model.save(str(model_path))

        return model
Example #8
    def test_receptor_flattened(self):
        path = EnvironmentSettings.root_path / "test/tmp/onehot_recep_flat/"

        PathBuilder.build(path)

        dataset = self.construct_test_flatten_dataset(path)

        encoder = OneHotEncoder.build_object(
            dataset, **{
                "use_positional_info": False,
                "distance_to_seq_middle": None,
                "sequence_type": "amino_acid",
                "flatten": True
            })

        encoded_data = encoder.encode(
            dataset,
            EncoderParams(result_path=path,
                          label_config=LabelConfiguration([
                              Label(name="l1",
                                    values=[1, 0],
                                    positive_class="1")
                          ]),
                          pool_size=1,
                          learn_model=True,
                          model={},
                          filename="dataset.pkl"))

        self.assertTrue(isinstance(encoded_data, ReceptorDataset))

        onehot_a = [1.0] + [0.0] * 19
        onehot_t = [0.0] * 16 + [1.0] + [0.0] * 3

        self.assertListEqual(
            list(encoded_data.encoded_data.examples[0]),
            onehot_a + onehot_a + onehot_a + onehot_t + onehot_t + onehot_t +
            onehot_a + onehot_t + onehot_a + onehot_t + onehot_a + onehot_t)
        self.assertListEqual(list(encoded_data.encoded_data.examples[1]),
                             onehot_a * 12)
        self.assertListEqual(list(encoded_data.encoded_data.examples[2]),
                             onehot_a * 12)

        self.assertListEqual(list(encoded_data.encoded_data.feature_names), [
            f"{chain}_{pos}_{char}" for chain in ("alpha", "beta")
            for pos in range(6)
            for char in EnvironmentSettings.get_sequence_alphabet()
        ])

        shutil.rmtree(path)
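The expected vectors in this test follow directly from each letter's index in the sorted 20-letter amino-acid alphabet. A short sketch of how onehot_a and onehot_t come about:

alphabet = list("ACDEFGHIKLMNPQRSTVWY")  # stand-in for EnvironmentSettings.get_sequence_alphabet()

def onehot(char):
    # 1.0 at the letter's alphabet index, 0.0 elsewhere
    return [1.0 if letter == char else 0.0 for letter in alphabet]

assert onehot("A") == [1.0] + [0.0] * 19              # 'A' has index 0
assert onehot("T") == [0.0] * 16 + [1.0] + [0.0] * 3  # 'T' has index 16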
Example #9
    def _generate(self) -> ReportResult:
        PathBuilder.build(self.result_path)
        report_result = ReportResult()
        sequence_alphabet = EnvironmentSettings.get_sequence_alphabet(
            self.method.sequence_type)
        for kernel_name in self.method.CNN.conv_chain_1 + self.method.CNN.conv_chain_2:
            figure_outputs, table_outputs = self._plot_kernels(
                kernel_name, sequence_alphabet)
            report_result.output_figures.extend(figure_outputs)
            report_result.output_tables.extend(table_outputs)

        figure_output, table_output = self._plot_fc_layer()
        report_result.output_figures.append(figure_output)
        report_result.output_tables.append(table_output)

        return report_result
Example #10
    def __init__(self, hamming_distance_probabilities: dict = None, min_gap: int = 0, max_gap: int = 0,
                 alphabet_weights: dict = None, position_weights: dict = None):
        if hamming_distance_probabilities is not None:
            hamming_distance_probabilities = {key: float(value) for key, value in hamming_distance_probabilities.items()}
            assert all(isinstance(key, int) for key in hamming_distance_probabilities.keys()) \
                   and all(isinstance(val, float) for val in hamming_distance_probabilities.values()) \
                   and 0.99 <= sum(hamming_distance_probabilities.values()) <= 1, \
                "GappedKmerInstantiation: for each possible Hamming distance a probability between 0 and 1 has to be assigned " \
                "so that the probabilities for all distance possibilities sum to 1."

        self._hamming_distance_probabilities = hamming_distance_probabilities
        self.position_weights = position_weights
        # if weights are not given for each letter of the alphabet, distribute the remaining probability
        # equally among letters
        self.alphabet_weights = self.set_default_weights(alphabet_weights, EnvironmentSettings.get_sequence_alphabet())
        self._min_gap = min_gap
        self._max_gap = max_gap
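The assertion above accepts the Hamming-distance dictionary only if its keys are ints, its values are floats, and the values sum (within tolerance) to 1. A standalone sketch of that validation:

# values may arrive as strings when parsed from YAML, hence the float() cast
hamming_distance_probabilities = {0: "0.7", 1: "0.3"}

probs = {key: float(value) for key, value in hamming_distance_probabilities.items()}
assert all(isinstance(key, int) for key in probs.keys())
assert all(isinstance(val, float) for val in probs.values())
assert 0.99 <= sum(probs.values()) <= 1, "probabilities must sum to 1"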
Example #11
    def _substitute_letters(self, position_weights, alphabet_weights, allowed_positions: list, instance: list):

        if self._hamming_distance_probabilities:
            substitution_count = random.choices(list(self._hamming_distance_probabilities.keys()),
                                                list(self._hamming_distance_probabilities.values()), k=1)[0]
            allowed_position_weights = {key: value for key, value in position_weights.items() if key in allowed_positions}
            position_probabilities = self._prepare_probabilities(allowed_position_weights)
            positions = list(np.random.choice(allowed_positions, size=substitution_count, p=position_probabilities))

            while substitution_count > 0:
                if position_weights[positions[substitution_count - 1]] > 0:  # if the position is allowed to be changed
                    position = positions[substitution_count - 1]
                    alphabet_probabilities = self._prepare_probabilities(alphabet_weights)
                    instance[position] = np.random.choice(EnvironmentSettings.get_sequence_alphabet(), size=1,
                                                          p=alphabet_probabilities)[0]
                substitution_count -= 1

        return instance
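One pass of the substitution logic, reduced to the standard library plus numpy; the alphabet and probabilities are toy values, and uniform position/letter choices stand in for the position and alphabet weights:

import random
import numpy as np

alphabet = list("ACGT")
instance = list("ACGTACGT")
hamming_distance_probabilities = {0: 0.5, 1: 0.3, 2: 0.2}

# how many positions to mutate, drawn from the Hamming-distance distribution
substitution_count = random.choices(list(hamming_distance_probabilities.keys()),
                                    list(hamming_distance_probabilities.values()), k=1)[0]

positions = list(np.random.choice(list(range(len(instance))), size=substitution_count))
for position in positions:
    instance[position] = np.random.choice(alphabet, size=1)[0]

print("".join(instance))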
Example #12
    def generate_repertoire_dataset(repertoire_count: int,
                                    sequence_count_probabilities: dict,
                                    sequence_length_probabilities: dict,
                                    labels: dict,
                                    path: Path) -> RepertoireDataset:
        """
        Creates repertoire_count repertoires where the number of sequences per repertoire is sampled from the probability distribution given
        in sequence_count_probabilities. The length of sequences is sampled independently for each sequence from
        sequence_length_probabilities distribution. The labels are also randomly assigned to repertoires from the distribution given in
        labels. In this case, labels are multi-class, so each repertoire will get one class from each label. This means that negative
        classes for the labels should be included as well in the specification.

        An example of input parameters is given below:
        repertoire_count: 100 # generate 100 repertoires
        sequence_count_probabilities:
            100: 0.5 # half of the generated repertoires will have 100 sequences
            200: 0.5 # the other half of the generated repertoires will have 200 sequences
        sequence_length_probabilities:
            14: 0.8 # 80% of all generated sequences for all repertoires will have length 14
            15: 0.2 # 20% of all generated sequences across all repertoires will have length 15
        labels:
            cmv: # label name
                True: 0.5 # 50% of the repertoires will have class True
                False: 0.5 # 50% of the repertoires will have class False
            coeliac: # next label with classes that will be assigned to repertoires independently of the previous label or any other parameter
                1: 0.3 # 30% of the generated repertoires will have class 1
                0: 0.7 # 70% of the generated repertoires will have class 0
        """
        RandomDatasetGenerator._check_rep_dataset_generation_params(
            repertoire_count, sequence_count_probabilities,
            sequence_length_probabilities, labels, path)

        alphabet = EnvironmentSettings.get_sequence_alphabet()
        PathBuilder.build(path)

        sequences = [[
            "".join(
                random.choices(alphabet,
                               k=random.choices(
                                   list(sequence_length_probabilities.keys()),
                                   sequence_length_probabilities.values())[0]))
            for seq_count in range(
                random.choices(list(sequence_count_probabilities.keys()),
                               sequence_count_probabilities.values())[0])
        ] for rep in range(repertoire_count)]

        if labels is not None:
            processed_labels = {
                label: random.choices(list(labels[label].keys()),
                                      labels[label].values(),
                                      k=repertoire_count)
                for label in labels
            }
            dataset_params = {
                label: list(labels[label].keys())
                for label in labels
            }
        else:
            processed_labels = None
            dataset_params = None

        repertoires, metadata = RepertoireBuilder.build(
            sequences=sequences, path=path, labels=processed_labels)
        dataset = RepertoireDataset(labels=dataset_params,
                                    repertoires=repertoires,
                                    metadata_file=metadata)

        return dataset
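The docstring's YAML parameters map one-to-one onto the two random.choices calls above. A standalone sketch of just the sampling step, using the documented example values:

import random

sequence_count_probabilities = {100: 0.5, 200: 0.5}
sequence_length_probabilities = {14: 0.8, 15: 0.2}
alphabet = list("ACDEFGHIKLMNPQRSTVWY")
repertoire_count = 3

# one inner list of sequences per repertoire; counts and lengths are sampled
# from the probability dicts exactly as in generate_repertoire_dataset
sequences = [["".join(random.choices(alphabet,
                                     k=random.choices(list(sequence_length_probabilities.keys()),
                                                      list(sequence_length_probabilities.values()))[0]))
              for _ in range(random.choices(list(sequence_count_probabilities.keys()),
                                            list(sequence_count_probabilities.values()))[0])]
             for _ in range(repertoire_count)]

print([len(repertoire) for repertoire in sequences])  # e.g. [200, 100, 200]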
Example #13
    def generate_sequence_dataset(sequence_count: int,
                                  length_probabilities: dict, labels: dict,
                                  path: Path):
        """
        Creates sequence_count receptor sequences (single chain) where the length of each sequence is sampled independently from the
        length_probabilities distribution. The labels are also randomly assigned to sequences from the distribution given in
        labels. In this case, labels are multi-class, so each sequence will get one class from each label. This means that negative
        classes for the labels should be included as well in the specification.

        An example of input parameters is given below:

        sequence_count: 100 # generate 100 TRB ReceptorSequences
        length_probabilities:
            14: 0.8 # 80% of all generated sequences for all receptors (for chain 1) will have length 14
            15: 0.2 # 20% of all generated sequences across all receptors (for chain 1) will have length 15
        labels:
            epitope1: # label name
                True: 0.5 # 50% of the receptors will have class True
                False: 0.5 # 50% of the receptors will have class False
            epitope2: # next label with classes that will be assigned to receptors independently of the previous label or other parameters
                1: 0.3 # 30% of the generated receptors will have class 1
                0: 0.7 # 70% of the generated receptors will have class 0
        """
        RandomDatasetGenerator._check_sequence_dataset_generation_params(
            sequence_count, length_probabilities, labels, path)

        alphabet = EnvironmentSettings.get_sequence_alphabet()
        PathBuilder.build(path)

        chain = "TRB"

        sequences = [
            ReceptorSequence(
                "".join(
                    random.choices(alphabet,
                                   k=random.choices(
                                       list(length_probabilities.keys()),
                                       length_probabilities.values())[0])),
                metadata=SequenceMetadata(
                    count=1,
                    v_subgroup=chain + "V1",
                    v_gene=chain + "V1-1",
                    v_allele=chain + "V1-1*01",
                    j_subgroup=chain + "J1",
                    j_gene=chain + "J1-1",
                    j_allele=chain + "J1-1*01",
                    chain=chain,
                    custom_params={
                        **{
                            label: random.choices(list(label_dict.keys()),
                                                  label_dict.values(),
                                                  k=1)[0]
                            for label, label_dict in labels.items()
                        },
                        **{
                            "subject": f"subj_{i + 1}"
                        }
                    })) for i in range(sequence_count)
        ]

        filename = path / "batch01.npy"

        sequence_matrix = np.core.records.fromrecords(
            [seq.get_record() for seq in sequences],
            names=ReceptorSequence.get_record_names())
        np.save(str(filename), sequence_matrix, allow_pickle=False)

        return SequenceDataset(labels={label: list(label_dict.keys()) for label, label_dict in labels.items()},
                               filenames=[filename],
                               file_size=sequence_count)
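Each label's class is drawn independently per example. A short sketch of the custom_params construction above, using the docstring's label specification:

import random

labels = {"epitope1": {True: 0.5, False: 0.5}, "epitope2": {1: 0.3, 0: 0.7}}

# one independently sampled class per label, as in the custom_params dict above
custom_params = {label: random.choices(list(label_dict.keys()),
                                       list(label_dict.values()), k=1)[0]
                 for label, label_dict in labels.items()}

print(custom_params)  # e.g. {'epitope1': True, 'epitope2': 0}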
Example #14
    def generate_receptor_dataset(receptor_count: int,
                                  chain_1_length_probabilities: dict,
                                  chain_2_length_probabilities: dict,
                                  labels: dict, path: Path):
        """
        Creates receptor_count receptors where the length of sequences in each chain is sampled independently for each sequence from
        chain_n_length_probabilities distribution. The labels are also randomly assigned to receptors from the distribution given in
        labels. In this case, labels are multi-class, so each receptor will get one class from each label. This means that negative
        classes for the labels should be included as well in the specification. Chains 1 and 2 in this case refer to the alpha and beta
        chains of a T-cell receptor.

        An example of input parameters is given below:

        receptor_count: 100 # generate 100 TRABReceptors
        chain_1_length_probabilities:
            14: 0.8 # 80% of all generated sequences for all receptors (for chain 1) will have length 14
            15: 0.2 # 20% of all generated sequences across all receptors (for chain 1) will have length 15
        chain_2_length_probabilities:
            14: 0.8 # 80% of all generated sequences for all receptors (for chain 2) will have length 14
            15: 0.2 # 20% of all generated sequences across all receptors (for chain 2) will have length 15
        labels:
            epitope1: # label name
                True: 0.5 # 50% of the receptors will have class True
                False: 0.5 # 50% of the receptors will have class False
            epitope2: # next label with classes that will be assigned to receptors independently of the previous label or other parameters
                1: 0.3 # 30% of the generated receptors will have class 1
                0: 0.7 # 70% of the generated receptors will have class 0
        """
        RandomDatasetGenerator._check_receptor_dataset_generation_params(
            receptor_count, chain_1_length_probabilities,
            chain_2_length_probabilities, labels, path)

        alphabet = EnvironmentSettings.get_sequence_alphabet()
        PathBuilder.build(path)

        def get_random_sequence(proba, chain, id):
            # sample a length from proba, then draw that many letters from the alphabet
            return ReceptorSequence(
                "".join(random.choices(alphabet, k=random.choices(list(proba.keys()), proba.values())[0])),
                metadata=SequenceMetadata(count=1,
                                          v_subgroup=chain + "V1",
                                          v_gene=chain + "V1-1",
                                          v_allele=chain + "V1-1*01",
                                          j_subgroup=chain + "J1",
                                          j_gene=chain + "J1-1",
                                          j_allele=chain + "J1-1*01",
                                          chain=chain,
                                          cell_id=id))

        receptors = [
            TCABReceptor(alpha=get_random_sequence(
                chain_1_length_probabilities, "TRA", i),
                         beta=get_random_sequence(chain_2_length_probabilities,
                                                  "TRB", i),
                         metadata={
                             **{
                                 label: random.choices(list(label_dict.keys()),
                                                       label_dict.values(),
                                                       k=1)[0]
                                 for label, label_dict in labels.items()
                             },
                             **{
                                 "subject": f"subj_{i + 1}"
                             }
                         }) for i in range(receptor_count)
        ]

        filename = path / "batch01.npy"

        receptor_matrix = np.core.records.fromrecords(
            [receptor.get_record() for receptor in receptors],
            names=TCABReceptor.get_record_names())
        np.save(str(filename), receptor_matrix, allow_pickle=False)

        return ReceptorDataset(labels={label: list(label_dict.keys()) for label, label_dict in labels.items()},
                               filenames=[filename],
                               file_size=receptor_count,
                               element_class_name=type(receptors[0]).__name__ if len(receptors) > 0 else None)
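The receptors are persisted as a numpy record array, which can be saved without pickling. A minimal sketch with an illustrative field layout (the real field names come from TCABReceptor.get_record_names, not from this example):

import numpy as np

# illustrative fields only; the real layout comes from get_record_names()
records = [("CASSL", "TRA", 0), ("CASSF", "TRB", 1)]
matrix = np.core.records.fromrecords(records, names=["sequence", "chain", "cell_id"])

np.save("batch01.npy", matrix, allow_pickle=False)  # fixed-width dtypes need no pickle
loaded = np.load("batch01.npy")
print(loaded["sequence"])  # ['CASSL' 'CASSF']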
Example #15
    def test_repertoire_flattened(self):
        path = EnvironmentSettings.root_path / "test/tmp/onehot_recep_flat/"

        PathBuilder.build(path)

        dataset, lc = self._construct_test_repertoiredataset(path, positional=False)

        encoder = OneHotEncoder.build_object(dataset, **{"use_positional_info": False, "distance_to_seq_middle": None,
                                                         "flatten": True})

        encoded_data = encoder.encode(dataset, EncoderParams(
            result_path=path,
            label_config=lc,
            pool_size=1,
            learn_model=True,
            model={},
            filename="dataset.pkl"
        ))

        self.assertTrue(isinstance(encoded_data, RepertoireDataset))

        onehot_a = [1.0] + [0.0] * 19
        onehot_t = [0.0] * 16 + [1.0] + [0.0] * 3
        onehot_empty = [0.0] * 20

        self.assertListEqual(list(encoded_data.encoded_data.examples[0]),
                             onehot_a + onehot_a + onehot_a + onehot_a + onehot_a + onehot_t + onehot_a + onehot_empty + onehot_a + onehot_t + onehot_a + onehot_empty)
        self.assertListEqual(list(encoded_data.encoded_data.examples[1]),
                             onehot_a + onehot_t + onehot_a + onehot_empty + onehot_t + onehot_a + onehot_a + onehot_empty + onehot_empty + onehot_empty + onehot_empty + onehot_empty)

        self.assertListEqual(list(encoded_data.encoded_data.feature_names), [f"{seq}_{pos}_{char}" for seq in range(3) for pos in range(4) for char in EnvironmentSettings.get_sequence_alphabet()])

        shutil.rmtree(path)
Example #16
class OneHotEncoder(DatasetEncoder):
    """
    One-hot encoding for repertoires, sequences or receptors. In one-hot encoding, each alphabet character
    (amino acid or nucleotide) is replaced by a sparse vector with one 1 and the rest zeroes. The position of the
    1 represents the alphabet character.


    Arguments:

        use_positional_info (bool): whether to include features representing the positional information.
        If True, three additional feature vectors will be added, representing the sequence start, sequence middle
        and sequence end. The values in these features are scaled between 0 and 1. A graphical representation of
        the values of these vectors is given below.

        .. code-block:: console

              Value of sequence start:         Value of sequence middle:        Value of sequence end:

            1 \                              1    /‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾\         1                          /
               \                                 /                   \                                  /
                \                               /                     \                                /
            0    \_____________________      0 /                       \      0  _____________________/
              <----sequence length---->        <----sequence length---->         <----sequence length---->


        distance_to_seq_middle (int): only applies when use_positional_info is True. This is the distance from the edge
        of the CDR3 sequence (IMGT positions 105 and 117) to the portion of the sequence that is considered 'middle'.
        For example: if distance_to_seq_middle is 6 (default), all IMGT positions in the interval [111, 112)
        receive positional value 1.
        When using nucleotide sequences: note that the distance is measured in (amino acid) IMGT positions.
        If the complete sequence length is smaller than 2 * distance_to_seq_middle, the minimum value of the
        'start' and 'end' vectors will not reach 0, and the maximum value of the 'middle' vector will not reach 1.
        A graphical representation of the positional vectors for a sequence that is too short is given below:


        .. code-block:: console

            Value of sequence start         Value of sequence middle        Value of sequence end:
            with very short sequence:       with very short sequence:       with very short sequence:

                 1 \                               1                                 1    /
                    \                                                                    /
                     \                                /\                                /
                 0                                 0 /  \                            0
                   <->                               <-->                               <->

        flatten (bool): whether to flatten the final onehot matrix to a 2-dimensional matrix [examples, other_dims_combined].
        This must be set to True when using onehot encoding in combination with scikit-learn ML methods (inheriting :py:obj:`~source.ml_methods.SklearnMethod.SklearnMethod`),
        such as :ref:`LogisticRegression`, :ref:`SVM`, :ref:`RandomForestClassifier` and :ref:`KNN`.


    YAML specification:

    .. indent with spaces
    .. code-block:: yaml

        one_hot_vanilla:
            OneHot:
                use_positional_info: False
                flatten: False

        one_hot_positional:
            OneHot:
                use_positional_info: True
                distance_to_seq_middle: 3
                flatten: False

    """

    dataset_mapping = {
        "RepertoireDataset": "OneHotRepertoireEncoder",
        "SequenceDataset": "OneHotSequenceEncoder",
        "ReceptorDataset": "OneHotReceptorEncoder"
    }

    ALPHABET = EnvironmentSettings.get_sequence_alphabet()

    def __init__(self,
                 use_positional_info: bool,
                 distance_to_seq_middle: int,
                 flatten: bool,
                 name: str = None):
        self.use_positional_info = use_positional_info
        self.distance_to_seq_middle = distance_to_seq_middle
        self.flatten = flatten

        if distance_to_seq_middle:
            self.pos_increasing = [
                1 / self.distance_to_seq_middle * i
                for i in range(self.distance_to_seq_middle)
            ]
            self.pos_decreasing = self.pos_increasing[::-1]
        else:
            # define both ramps so later attribute access cannot fail when positional info is off
            self.pos_increasing = None
            self.pos_decreasing = None

        self.name = name

        # one amino acid corresponds to three nucleotides; guard against None,
        # which occurs when positional info is off
        if EnvironmentSettings.get_sequence_type() == SequenceType.NUCLEOTIDE and self.distance_to_seq_middle is not None:
            self.distance_to_seq_middle = self.distance_to_seq_middle * 3  # todo check this / explain in docs

        self.onehot_dimensions = (self.ALPHABET + ["start", "mid", "end"]) if self.use_positional_info \
            else self.ALPHABET  # todo test this

    @staticmethod
    def _prepare_parameters(use_positional_info,
                            distance_to_seq_middle,
                            flatten,
                            name: str = None):

        location = OneHotEncoder.__name__

        ParameterValidator.assert_type_and_value(use_positional_info, bool,
                                                 location,
                                                 "use_positional_info")
        if use_positional_info:
            ParameterValidator.assert_type_and_value(distance_to_seq_middle,
                                                     int,
                                                     location,
                                                     "distance_to_seq_middle",
                                                     min_inclusive=1)
        else:
            distance_to_seq_middle = None

        ParameterValidator.assert_type_and_value(flatten, bool, location,
                                                 "flatten")

        return {
            "use_positional_info": use_positional_info,
            "distance_to_seq_middle": distance_to_seq_middle,
            "flatten": flatten,
            "name": name
        }

    @staticmethod
    def build_object(dataset=None, **params):

        try:
            prepared_params = OneHotEncoder._prepare_parameters(**params)
            encoder = ReflectionHandler.get_class_by_name(
                OneHotEncoder.dataset_mapping[dataset.__class__.__name__],
                "onehot/")(**prepared_params)
        except (ValueError, KeyError):
            # a KeyError from dataset_mapping means no encoder exists for this dataset type
            raise ValueError(
                "{} is not defined for dataset of type {}.".format(
                    OneHotEncoder.__name__, dataset.__class__.__name__))
        return encoder

    def encode(self, dataset, params: EncoderParams):
        encoded_dataset = CacheHandler.memo_by_params(
            self._prepare_caching_params(dataset, params),
            lambda: self._encode_new_dataset(dataset, params))

        return encoded_dataset

    def _prepare_caching_params(self, dataset, params: EncoderParams, step: str = ""):
        return (("example_identifiers", tuple(dataset.get_example_ids())),
                ("dataset_metadata", dataset.metadata_file if hasattr(dataset, "metadata_file") else None),
                ("dataset_type", dataset.__class__.__name__),
                ("labels", tuple(params.label_config.get_labels_by_name())),
                ("encoding", OneHotEncoder.__name__),
                ("learn_model", params.learn_model),
                ("step", step),
                ("encoding_params", tuple(vars(self).items())))

    @abc.abstractmethod
    def _encode_new_dataset(self, dataset, params: EncoderParams):
        pass

    def store(self, encoded_dataset, params: EncoderParams):
        PickleExporter.export(encoded_dataset, params.result_path)

    def _encode_sequence_list(self, sequences, pad_n_sequences,
                              pad_sequence_len):
        char_array = np.array(sequences, dtype=str)
        char_array = char_array.view('U1').reshape((char_array.size, -1))

        n_sequences, sequence_len = char_array.shape

        sklearn_enc = SklearnOneHotEncoder(
            categories=[OneHotEncoder.ALPHABET for i in range(sequence_len)],
            handle_unknown='ignore')
        encoded_data = sklearn_enc.fit_transform(char_array).toarray()

        encoded_data = np.pad(encoded_data,
                              pad_width=((0, pad_n_sequences - n_sequences),
                                         (0, 0)))
        encoded_data = encoded_data.reshape(
            (pad_n_sequences, sequence_len, len(OneHotEncoder.ALPHABET)))
        positional_dims = int(self.use_positional_info) * 3
        encoded_data = np.pad(encoded_data,
                              pad_width=((0, 0),
                                         (0, pad_sequence_len - sequence_len),
                                         (0, positional_dims)))

        if self.use_positional_info:
            pos_info = [
                self._get_imgt_position_weights(len(sequence),
                                                pad_length=pad_sequence_len).T
                for sequence in sequences
            ]
            pos_info = np.stack(pos_info)
            pos_info = np.pad(pos_info,
                              pad_width=((0, pad_n_sequences - n_sequences),
                                         (0, 0), (0, 0)))

            encoded_data[:, :, len(OneHotEncoder.ALPHABET):] = pos_info

        return encoded_data

    def _get_imgt_position_weights(self, seq_length, pad_length=None):
        start_weights = self._get_imgt_start_weights(seq_length)
        mid_weights = self._get_imgt_mid_weights(seq_length)
        end_weights = start_weights[::-1]

        weights = np.array([start_weights, mid_weights, end_weights])

        if pad_length is not None:
            weights = np.pad(weights,
                             pad_width=((0, 0), (0, pad_length - seq_length)))

        return weights

    def _get_imgt_mid_weights(self, seq_length):
        mid_len = seq_length - (self.distance_to_seq_middle * 2)

        if mid_len >= 0:
            mid_weights = self.pos_increasing + [1] * mid_len + self.pos_decreasing
        else:
            left_idx = math.ceil(seq_length / 2)
            right_idx = math.floor(seq_length / 2)

            mid_weights = self.pos_increasing[:left_idx] + self.pos_decreasing[-right_idx:]

        return mid_weights

    def _get_imgt_start_weights(self, seq_length):
        diff = (seq_length - self.distance_to_seq_middle) - 1
        if diff >= 0:
            start_weights = [1] + self.pos_decreasing + [0] * diff
        else:
            start_weights = [1] + self.pos_decreasing[:diff]

        return start_weights
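For a sequence at least 2 * distance_to_seq_middle long, the three helpers above produce a full start ramp, a mid plateau and a mirrored end ramp. A standalone re-implementation of that long-sequence case, assuming distance_to_seq_middle = 3:

import numpy as np

d = 3  # distance_to_seq_middle
pos_increasing = [1 / d * i for i in range(d)]
pos_decreasing = pos_increasing[::-1]

def imgt_position_weights(seq_length):
    # long-sequence case only (seq_length >= 2 * d)
    start = [1] + pos_decreasing + [0] * (seq_length - d - 1)           # as _get_imgt_start_weights
    mid = pos_increasing + [1] * (seq_length - 2 * d) + pos_decreasing  # as _get_imgt_mid_weights
    end = start[::-1]                                                   # reversed start, as above
    return np.array([start, mid, end])

print(imgt_position_weights(10).round(2))
# row 0 ramps 1 -> 0, row 1 plateaus at 1 in the middle, row 2 ramps 0 -> 1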