def __init__(self, kernel_count: int = None, kernel_size=None, positional_channels: int = None, sequence_type: str = None,
             device=None, number_of_threads: int = None, random_seed: int = None, learning_rate: float = None,
             iteration_count: int = None, l1_weight_decay: float = None, l2_weight_decay: float = None, batch_size: int = None,
             training_percentage: float = None, evaluate_at: int = None, background_probabilities=None, result_path: Path = None):
    super().__init__()

    self.kernel_count = kernel_count
    self.kernel_size = kernel_size
    self.positional_channels = positional_channels
    self.number_of_threads = number_of_threads
    self.random_seed = random_seed
    self.device = device
    self.l1_weight_decay = l1_weight_decay
    self.l2_weight_decay = l2_weight_decay
    self.learning_rate = learning_rate
    self.iteration_count = iteration_count
    self.batch_size = batch_size
    self.evaluate_at = evaluate_at
    self.training_percentage = training_percentage
    self.sequence_type = SequenceType[sequence_type.upper()]
    # default: uniform background distribution over the sequence alphabet if none is provided
    self.background_probabilities = background_probabilities if background_probabilities is not None \
        else np.array([1. / len(EnvironmentSettings.get_sequence_alphabet(self.sequence_type))
                       for i in range(len(EnvironmentSettings.get_sequence_alphabet(self.sequence_type)))])

    self.CNN = None
    self.label_name = None
    self.class_mapping = None
    self.result_path = result_path
    self.chain_names = None
    self.feature_names = None
def get_repertoire_contents(repertoire, compairr_params):
    attributes = [EnvironmentSettings.get_sequence_type().value, "counts"]
    attributes += [] if compairr_params.ignore_genes else ["v_genes", "j_genes"]
    repertoire_contents = repertoire.get_attributes(attributes)
    repertoire_contents = pd.DataFrame({**repertoire_contents, "identifier": repertoire.identifier})

    check_na_rows = [EnvironmentSettings.get_sequence_type().value]
    check_na_rows += [] if compairr_params.ignore_counts else ["counts"]
    check_na_rows += [] if compairr_params.ignore_genes else ["v_genes", "j_genes"]

    n_rows_before = len(repertoire_contents)
    repertoire_contents.dropna(inplace=True, subset=check_na_rows)

    if n_rows_before > len(repertoire_contents):
        warnings.warn(f"CompAIRRHelper: removed {n_rows_before - len(repertoire_contents)} entries from repertoire {repertoire.identifier} "
                      f"due to missing values.")

    if compairr_params.ignore_counts:
        repertoire_contents["counts"] = 1

    # rename columns to the AIRR-style names expected by CompAIRR
    repertoire_contents.rename(columns={EnvironmentSettings.get_sequence_type().value: "junction_aa", "v_genes": "v_call", "j_genes": "j_call",
                                        "counts": "duplicate_count", "identifier": "repertoire_id"}, inplace=True)

    return repertoire_contents
def drop_illegal_character_sequences(dataframe: pd.DataFrame, import_illegal_characters: bool) -> pd.DataFrame:
    if not import_illegal_characters:
        sequence_type = EnvironmentSettings.get_sequence_type()
        sequence_name = sequence_type.name.lower().replace("_", " ")

        legal_alphabet = EnvironmentSettings.get_sequence_alphabet(sequence_type)
        if sequence_type == SequenceType.AMINO_ACID:
            legal_alphabet.append(Constants.STOP_CODON)

        is_illegal_seq = [ImportHelper.is_illegal_sequence(sequence, legal_alphabet) for sequence in dataframe[sequence_type.value]]
        n_illegal = sum(is_illegal_seq)

        if n_illegal > 0:
            dataframe.drop(dataframe.loc[is_illegal_seq].index, inplace=True)
            warnings.warn(f"{ImportHelper.__name__}: {n_illegal} sequences were removed from the dataset because their {sequence_name} sequence "
                          f"contained illegal characters.")
    return dataframe
def test_get_sequence(self):
    sequence = ReceptorSequence(amino_acid_sequence="CAS", nucleotide_sequence="TGTGCTTCC")
    EnvironmentSettings.set_sequence_type(SequenceType.AMINO_ACID)
    self.assertEqual(sequence.get_sequence(), "CAS")
def _build_filename(cache_key: str, object_type: CacheObjectType, cache_type=None) -> Path:
    path = EnvironmentSettings.get_cache_path(cache_type) / object_type.name.lower()
    PathBuilder.build(path)
    return path / f"{cache_key}.pickle"
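# Usage sketch: with the hypothetical key "a123" and CacheObjectType.ENCODING, the helper above
# resolves to "<cache_path>/encoding/a123.pickle" (compare the cache test further below).
cache_file = CacheHandler._build_filename(cache_key="a123", object_type=CacheObjectType.ENCODING)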
def get_relevant_sequence_attributes(self):
    attributes = [EnvironmentSettings.get_sequence_type().value]

    if not self.compairr_params.ignore_genes:
        attributes += ["v_genes", "j_genes"]

    return attributes
def __init__(self, use_positional_info: bool, distance_to_seq_middle: int, flatten: bool, name: str = None, sequence_type: SequenceType = None):
    self.use_positional_info = use_positional_info
    self.distance_to_seq_middle = distance_to_seq_middle
    self.flatten = flatten
    self.sequence_type = sequence_type
    self.alphabet = EnvironmentSettings.get_sequence_alphabet(self.sequence_type)

    if distance_to_seq_middle:
        self.pos_increasing = [1 / self.distance_to_seq_middle * i for i in range(self.distance_to_seq_middle)]
        self.pos_decreasing = self.pos_increasing[::-1]
    else:
        self.pos_decreasing = None

    self.name = name

    if self.sequence_type == SequenceType.NUCLEOTIDE and self.distance_to_seq_middle is not None:  # todo check this / explain in docs
        self.distance_to_seq_middle = self.distance_to_seq_middle * 3

    self.onehot_dimensions = self.alphabet + ["start", "mid", "end"] if self.use_positional_info else self.alphabet  # todo test this
def __init__(self, kernel_count: int, kernel_size, positional_channels: int, sequence_type: SequenceType, background_probabilities, chain_names):
    super(PyTorchReceptorCNN, self).__init__()
    self.background_probabilities = background_probabilities
    self.threshold = 0.1
    self.pseudocount = 0.05
    self.in_channels = len(EnvironmentSettings.get_sequence_alphabet(sequence_type)) + positional_channels
    self.positional_channels = positional_channels
    self.max_information_gain = self.get_max_information_gain()
    self.chain_names = chain_names

    self.conv_chain_1 = [f"chain_1_kernel_{size}" for size in kernel_size]
    self.conv_chain_2 = [f"chain_2_kernel_{size}" for size in kernel_size]

    for size in kernel_size:
        # chain 1
        setattr(self, f"chain_1_kernel_{size}", nn.Conv1d(in_channels=self.in_channels, out_channels=kernel_count, kernel_size=size, bias=True))
        getattr(self, f"chain_1_kernel_{size}").weight.data \
            .normal_(0.0, np.sqrt(1 / np.prod(getattr(self, f"chain_1_kernel_{size}").weight.shape)))

        # chain 2
        setattr(self, f"chain_2_kernel_{size}", nn.Conv1d(in_channels=self.in_channels, out_channels=kernel_count, kernel_size=size, bias=True))
        getattr(self, f"chain_2_kernel_{size}").weight.data \
            .normal_(0.0, np.sqrt(1 / np.prod(getattr(self, f"chain_2_kernel_{size}").weight.shape)))

    self.fully_connected = nn.Linear(in_features=kernel_count * len(kernel_size) * 2, out_features=1, bias=True)
    self.fully_connected.weight.data.normal_(0.0, np.sqrt(1 / np.prod(self.fully_connected.weight.shape)))
def _build_new_sequence(self, sequence: ReceptorSequence, position, signal: dict) -> ReceptorSequence:
    gap_length = signal["motif_instance"].gap
    if "/" in signal["motif_instance"].instance:
        motif_left, motif_right = signal["motif_instance"].instance.split("/")
    else:
        motif_left = signal["motif_instance"].instance
        motif_right = ""

    gap_start = position + len(motif_left)
    gap_end = gap_start + gap_length
    part1 = sequence.get_sequence()[:position]
    part2 = sequence.get_sequence()[gap_start:gap_end]
    part3 = sequence.get_sequence()[gap_end + len(motif_right):]

    new_sequence_string = part1 + motif_left + part2 + motif_right + part3

    annotation = SequenceAnnotation()
    implant = ImplantAnnotation(signal_id=signal["signal_id"], motif_id=signal["motif_id"], motif_instance=signal["motif_instance"],
                                position=position)
    annotation.add_implant(implant)

    new_sequence = ReceptorSequence()
    new_sequence.set_annotation(annotation)
    new_sequence.set_metadata(copy.deepcopy(sequence.metadata))
    new_sequence.set_sequence(new_sequence_string, EnvironmentSettings.get_sequence_type())

    return new_sequence
def make_random_dataset(self, path):
    alphabet = EnvironmentSettings.get_sequence_alphabet()
    sequences = [["".join([rn.choice(alphabet) for i in range(20)]) for i in range(100)] for i in range(40)]

    repertoires, metadata = RepertoireBuilder.build(sequences, path, subject_ids=[i % 2 for i in range(len(sequences))])
    dataset = RepertoireDataset(repertoires=repertoires, metadata_file=metadata)
    PickleExporter.export(dataset, path)
def _encode_repertoire(self, repertoire, params: EncoderParams):
    sequences = repertoire.get_attribute(EnvironmentSettings.get_sequence_type().value)
    onehot_encoded = self._encode_sequence_list(sequences, pad_n_sequences=self.max_rep_len, pad_sequence_len=self.max_seq_len)
    example_id = repertoire.identifier
    labels = self._get_repertoire_labels(repertoire, params) if params.encode_labels else None

    return onehot_encoded, example_id, labels
def get_sequence(self, sequence_type: SequenceType = None):
    """Returns the receptor sequence (nucleotide or amino acid) corresponding to the provided sequence type,
    or to the sequence type preset in the EnvironmentSettings class if no type is provided."""
    sequence_type_ = EnvironmentSettings.get_sequence_type() if sequence_type is None else sequence_type
    if sequence_type_ == SequenceType.AMINO_ACID:
        return self.amino_acid_sequence
    else:
        return self.nucleotide_sequence
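# Usage sketch (values borrowed from the test above): an explicit sequence type overrides the type
# preset via EnvironmentSettings, otherwise the preset type is used.
seq = ReceptorSequence(amino_acid_sequence="CAS", nucleotide_sequence="TGTGCTTCC")
EnvironmentSettings.set_sequence_type(SequenceType.NUCLEOTIDE)
assert seq.get_sequence() == "TGTGCTTCC"                  # falls back to the preset sequence type
assert seq.get_sequence(SequenceType.AMINO_ACID) == "CAS"  # explicit argument takes precedence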
def get_sequence(self):
    """
    :return: receptor sequence (nucleotide or amino acid) corresponding to the sequence type preset in the EnvironmentSettings class
    """
    if EnvironmentSettings.get_sequence_type() == SequenceType.AMINO_ACID:
        return self.amino_acid_sequence
    else:
        return self.nucleotide_sequence
def _set_max_dims(self, dataset):
    max_rep_len = 0
    max_seq_len = 0

    for repertoire in dataset.repertoires:
        sequences = repertoire.get_attribute(EnvironmentSettings.get_sequence_type().value)
        max_rep_len = max(len(sequences), max_rep_len)
        max_seq_len = max(max([len(seq) for seq in sequences]), max_seq_len)

    self.max_rep_len = max_rep_len
    self.max_seq_len = max_seq_len
def test_memo_with_object_type(self):
    fn = lambda: "abc"
    cache_key = "a123"
    obj = CacheHandler.memo(cache_key, fn, CacheObjectType.ENCODING)
    self.assertEqual("abc", obj)
    self.assertTrue(os.path.isfile(EnvironmentSettings.get_cache_path() / f"encoding/{cache_key}.pickle"))
    os.remove(CacheHandler._build_filename(cache_key, CacheObjectType.ENCODING))
def add(params: tuple, caching_object, object_type: CacheObjectType = CacheObjectType.OTHER, cache_type=None):
    PathBuilder.build(EnvironmentSettings.get_cache_path(cache_type))
    h = CacheHandler.generate_cache_key(params)
    filename = CacheHandler._build_filename(cache_key=h, object_type=object_type, cache_type=cache_type)
    with filename.open("wb") as file:
        dill.dump(caching_object, file, protocol=pickle.HIGHEST_PROTOCOL)
def write_sequence_set_file(self, sequence_set, filename, offset=0):
    sequence_col = "junction_aa" if EnvironmentSettings.get_sequence_type() == SequenceType.AMINO_ACID else "junction"
    vj_header = "" if self.compairr_params.ignore_genes else "\tv_call\tj_call"

    with open(filename, "w") as file:
        file.write(f"{sequence_col}{vj_header}\tduplicate_count\trepertoire_id\n")
        for id, sequence_info in enumerate(sequence_set, offset):
            file.write("\t".join(sequence_info) + f"\t1\t{id}\n")
def create_model(self, dataset: RepertoireDataset, k: int, vector_size: int, batch_size: int, model_path: Path):
    # creates an empty model; note that `size` is the gensim 3.x name of the embedding-dimension argument
    # (renamed to `vector_size` in gensim 4)
    model = Word2Vec(size=vector_size, min_count=1, window=5)
    all_kmers = KmerHelper.create_all_kmers(k=k, alphabet=EnvironmentSettings.get_sequence_alphabet())
    all_kmers = [[kmer] for kmer in all_kmers]
    model.build_vocab(all_kmers)

    for kmer in all_kmers:
        sentences = KmerHelper.create_kmers_within_HD(kmer=kmer[0], alphabet=EnvironmentSettings.get_sequence_alphabet(), distance=1)
        model.train(sentences=sentences, total_words=len(all_kmers), epochs=model.epochs)

    model.save(str(model_path))
    return model
def create_model(self, dataset: RepertoireDataset, k: int, vector_size: int, batch_size: int, model_path: Path):
    model = Word2Vec(size=vector_size, min_count=1, window=5)  # creates an empty model
    all_kmers = KmerHelper.create_all_kmers(k=k, alphabet=EnvironmentSettings.get_sequence_alphabet())
    all_kmers = [[kmer] for kmer in all_kmers]
    model.build_vocab(all_kmers)

    for repertoire in dataset.get_data(batch_size=batch_size):
        sentences = KmerHelper.create_sentences_from_repertoire(repertoire=repertoire, k=k)
        model.train(sentences=sentences, total_words=len(all_kmers), epochs=15)

    model.save(str(model_path))
    return model
def test_get(self):
    params = (("k1", 1), ("k2", 2))
    obj = "object_example"
    object_type = CacheObjectType.OTHER
    h = hashlib.sha256(str(params).encode('utf-8')).hexdigest()
    filename = EnvironmentSettings.get_cache_path() / "{}/{}.pickle".format(CacheObjectType.OTHER.name.lower(), h)
    with open(filename, "wb") as file:
        pickle.dump(obj, file)

    obj2 = CacheHandler.get(params, object_type)
    self.assertEqual(obj, obj2)

    os.remove(filename)
def test_receptor_flattened(self):
    path = EnvironmentSettings.root_path / "test/tmp/onehot_recep_flat/"
    PathBuilder.build(path)

    dataset = self.construct_test_flatten_dataset(path)

    encoder = OneHotEncoder.build_object(dataset, **{"use_positional_info": False, "distance_to_seq_middle": None,
                                                     "sequence_type": "amino_acid", "flatten": True})

    encoded_data = encoder.encode(dataset, EncoderParams(result_path=path,
                                                         label_config=LabelConfiguration([Label(name="l1", values=[1, 0], positive_class="1")]),
                                                         pool_size=1, learn_model=True, model={}, filename="dataset.pkl"))

    self.assertTrue(isinstance(encoded_data, ReceptorDataset))

    onehot_a = [1.0] + [0.0] * 19
    onehot_t = [0.0] * 16 + [1.0] + [0] * 3

    self.assertListEqual(list(encoded_data.encoded_data.examples[0]),
                         onehot_a + onehot_a + onehot_a + onehot_t + onehot_t + onehot_t + onehot_a + onehot_t + onehot_a + onehot_t + onehot_a + onehot_t)
    self.assertListEqual(list(encoded_data.encoded_data.examples[1]), onehot_a * 12)
    self.assertListEqual(list(encoded_data.encoded_data.examples[2]), onehot_a * 12)

    self.assertListEqual(list(encoded_data.encoded_data.feature_names),
                         [f"{chain}_{pos}_{char}" for chain in ("alpha", "beta") for pos in range(6)
                          for char in EnvironmentSettings.get_sequence_alphabet()])

    shutil.rmtree(path)
def _generate(self) -> ReportResult:
    PathBuilder.build(self.result_path)

    report_result = ReportResult()
    sequence_alphabet = EnvironmentSettings.get_sequence_alphabet(self.method.sequence_type)

    for kernel_name in self.method.CNN.conv_chain_1 + self.method.CNN.conv_chain_2:
        figure_outputs, table_outputs = self._plot_kernels(kernel_name, sequence_alphabet)
        report_result.output_figures.extend(figure_outputs)
        report_result.output_tables.extend(table_outputs)

    figure_output, table_output = self._plot_fc_layer()
    report_result.output_figures.append(figure_output)
    report_result.output_tables.append(table_output)

    return report_result
def __init__(self, hamming_distance_probabilities: dict = None, min_gap: int = 0, max_gap: int = 0,
             alphabet_weights: dict = None, position_weights: dict = None):
    if hamming_distance_probabilities is not None:
        hamming_distance_probabilities = {key: float(value) for key, value in hamming_distance_probabilities.items()}
        assert all(isinstance(key, int) for key in hamming_distance_probabilities.keys()) \
            and all(isinstance(val, float) for val in hamming_distance_probabilities.values()) \
            and 0.99 <= sum(hamming_distance_probabilities.values()) <= 1, \
            "GappedKmerInstantiation: for each possible Hamming distance a probability between 0 and 1 has to be assigned " \
            "so that the probabilities for all distance possibilities sum to 1."

    self._hamming_distance_probabilities = hamming_distance_probabilities
    self.position_weights = position_weights
    # if weights are not given for each letter of the alphabet, distribute the remaining probability
    # equally among letters
    self.alphabet_weights = self.set_default_weights(alphabet_weights, EnvironmentSettings.get_sequence_alphabet())
    self._min_gap = min_gap
    self._max_gap = max_gap
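# Usage sketch: Hamming distance keys are ints and their probabilities must sum to (approximately) 1,
# as enforced by the assertion above; min_gap/max_gap bound the gap length of the k-mer.
instantiation = GappedKmerInstantiation(hamming_distance_probabilities={0: 0.7, 1: 0.3}, min_gap=0, max_gap=1)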
def _substitute_letters(self, position_weights, alphabet_weights, allowed_positions: list, instance: list):
    if self._hamming_distance_probabilities:
        substitution_count = random.choices(list(self._hamming_distance_probabilities.keys()),
                                            list(self._hamming_distance_probabilities.values()), k=1)[0]
        allowed_position_weights = {key: value for key, value in position_weights.items() if key in allowed_positions}
        position_probabilities = self._prepare_probabilities(allowed_position_weights)
        positions = list(np.random.choice(allowed_positions, size=substitution_count, p=position_probabilities))

        while substitution_count > 0:
            if position_weights[positions[substitution_count - 1]] > 0:  # if the position is allowed to be changed
                position = positions[substitution_count - 1]
                alphabet_probabilities = self._prepare_probabilities(alphabet_weights)
                instance[position] = np.random.choice(EnvironmentSettings.get_sequence_alphabet(), size=1, p=alphabet_probabilities)[0]
            substitution_count -= 1

    return instance
def add_by_key(cache_key: str, caching_object, object_type: CacheObjectType = CacheObjectType.OTHER, cache_type=None):
    PathBuilder.build(EnvironmentSettings.get_cache_path(cache_type))
    filename = CacheHandler._build_filename(cache_key=cache_key, object_type=object_type, cache_type=cache_type)
    try:
        with filename.open("wb") as file:
            dill.dump(caching_object, file, protocol=pickle.HIGHEST_PROTOCOL)
    except AttributeError:
        os.remove(filename)
        logging.warning(f"CacheHandler: could not cache object of class {type(caching_object).__name__} with key {cache_key}. "
                        f"Object: {caching_object}\n"
                        f"Next time this object is needed, it will be recomputed which will take more time but should not influence results.")
def generate_repertoire_dataset(repertoire_count: int, sequence_count_probabilities: dict, sequence_length_probabilities: dict,
                                labels: dict, path: Path) -> RepertoireDataset:
    """
    Creates repertoire_count repertoires where the number of sequences per repertoire is sampled from the probability distribution given in
    sequence_count_probabilities. The length of sequences is sampled independently for each sequence from the sequence_length_probabilities
    distribution. The labels are also randomly assigned to repertoires from the distribution given in labels. In this case, labels are
    multi-class, so each repertoire will get one class from each label. This means that negative classes for the labels should be included
    as well in the specification.

    An example of input parameters is given below:

    repertoire_count: 100 # generate 100 repertoires
    sequence_count_probabilities:
        100: 0.5 # half of the generated repertoires will have 100 sequences
        200: 0.5 # the other half of the generated repertoires will have 200 sequences
    sequence_length_probabilities:
        14: 0.8 # 80% of all generated sequences across all repertoires will have length 14
        15: 0.2 # 20% of all generated sequences across all repertoires will have length 15
    labels:
        cmv: # label name
            True: 0.5 # 50% of the repertoires will have class True
            False: 0.5 # 50% of the repertoires will have class False
        coeliac: # next label with classes that will be assigned to repertoires independently of the previous label or any other parameter
            1: 0.3 # 30% of the generated repertoires will have class 1
            0: 0.7 # 70% of the generated repertoires will have class 0
    """
    RandomDatasetGenerator._check_rep_dataset_generation_params(repertoire_count, sequence_count_probabilities,
                                                                sequence_length_probabilities, labels, path)

    alphabet = EnvironmentSettings.get_sequence_alphabet()
    PathBuilder.build(path)

    sequences = [["".join(random.choices(alphabet, k=random.choices(list(sequence_length_probabilities.keys()),
                                                                    sequence_length_probabilities.values())[0]))
                  for seq_count in range(random.choices(list(sequence_count_probabilities.keys()), sequence_count_probabilities.values())[0])]
                 for rep in range(repertoire_count)]

    if labels is not None:
        processed_labels = {label: random.choices(list(labels[label].keys()), labels[label].values(), k=repertoire_count) for label in labels}
        dataset_params = {label: list(labels[label].keys()) for label in labels}
    else:
        processed_labels = None
        dataset_params = None

    repertoires, metadata = RepertoireBuilder.build(sequences=sequences, path=path, labels=processed_labels)
    dataset = RepertoireDataset(labels=dataset_params, repertoires=repertoires, metadata_file=metadata)

    return dataset
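# Minimal call sketch mirroring the docstring example above (RandomDatasetGenerator is assumed to expose
# this function as a static method, as the internal _check_rep_dataset_generation_params call suggests;
# all parameter values are hypothetical):
dataset = RandomDatasetGenerator.generate_repertoire_dataset(repertoire_count=100,
                                                             sequence_count_probabilities={100: 0.5, 200: 0.5},
                                                             sequence_length_probabilities={14: 0.8, 15: 0.2},
                                                             labels={"cmv": {True: 0.5, False: 0.5}},
                                                             path=Path("random_repertoire_dataset/"))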
def generate_sequence_dataset(sequence_count: int, length_probabilities: dict, labels: dict, path: Path):
    """
    Creates sequence_count receptor sequences (single chain) where the length of each sequence is sampled independently from the
    length_probabilities distribution. The labels are also randomly assigned to sequences from the distribution given in labels. In this case,
    labels are multi-class, so each sequence will get one class from each label. This means that negative classes for the labels should be
    included as well in the specification.

    An example of input parameters is given below:

    sequence_count: 100 # generate 100 TRB ReceptorSequences
    length_probabilities:
        14: 0.8 # 80% of all generated sequences will have length 14
        15: 0.2 # 20% of all generated sequences will have length 15
    labels:
        epitope1: # label name
            True: 0.5 # 50% of the sequences will have class True
            False: 0.5 # 50% of the sequences will have class False
        epitope2: # next label with classes that will be assigned to sequences independently of the previous label or other parameters
            1: 0.3 # 30% of the generated sequences will have class 1
            0: 0.7 # 70% of the generated sequences will have class 0
    """
    RandomDatasetGenerator._check_sequence_dataset_generation_params(sequence_count, length_probabilities, labels, path)

    alphabet = EnvironmentSettings.get_sequence_alphabet()
    PathBuilder.build(path)

    chain = "TRB"

    sequences = [ReceptorSequence("".join(random.choices(alphabet, k=random.choices(list(length_probabilities.keys()),
                                                                                    length_probabilities.values())[0])),
                                  metadata=SequenceMetadata(count=1,
                                                            v_subgroup=chain + "V1", v_gene=chain + "V1-1", v_allele=chain + "V1-1*01",
                                                            j_subgroup=chain + "J1", j_gene=chain + "J1-1", j_allele=chain + "J1-1*01",
                                                            chain=chain,
                                                            custom_params={**{label: random.choices(list(label_dict.keys()),
                                                                                                    label_dict.values(), k=1)[0]
                                                                              for label, label_dict in labels.items()},
                                                                           **{"subject": f"subj_{i + 1}"}}))
                 for i in range(sequence_count)]

    filename = path / "batch01.npy"

    sequence_matrix = np.core.records.fromrecords([seq.get_record() for seq in sequences], names=ReceptorSequence.get_record_names())
    np.save(str(filename), sequence_matrix, allow_pickle=False)

    return SequenceDataset(labels={label: list(label_dict.keys()) for label, label_dict in labels.items()},
                           filenames=[filename], file_size=sequence_count)
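# Analogous call sketch for the single-chain variant (same static-method assumption and hypothetical values as above):
sequence_dataset = RandomDatasetGenerator.generate_sequence_dataset(sequence_count=100,
                                                                    length_probabilities={14: 0.8, 15: 0.2},
                                                                    labels={"epitope1": {True: 0.5, False: 0.5}},
                                                                    path=Path("random_sequence_dataset/"))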
def generate_receptor_dataset(receptor_count: int, chain_1_length_probabilities: dict, chain_2_length_probabilities: dict, labels: dict,
                              path: Path):
    """
    Creates receptor_count receptors where the length of sequences in each chain is sampled independently for each sequence from the
    chain_n_length_probabilities distribution. The labels are also randomly assigned to receptors from the distribution given in labels.
    In this case, labels are multi-class, so each receptor will get one class from each label. This means that negative classes for the
    labels should be included as well in the specification. Chains 1 and 2 in this case refer to the alpha and beta chain of a T-cell receptor.

    An example of input parameters is given below:

    receptor_count: 100 # generate 100 TRABReceptors
    chain_1_length_probabilities:
        14: 0.8 # 80% of all generated sequences for all receptors (for chain 1) will have length 14
        15: 0.2 # 20% of all generated sequences across all receptors (for chain 1) will have length 15
    chain_2_length_probabilities:
        14: 0.8 # 80% of all generated sequences for all receptors (for chain 2) will have length 14
        15: 0.2 # 20% of all generated sequences across all receptors (for chain 2) will have length 15
    labels:
        epitope1: # label name
            True: 0.5 # 50% of the receptors will have class True
            False: 0.5 # 50% of the receptors will have class False
        epitope2: # next label with classes that will be assigned to receptors independently of the previous label or other parameters
            1: 0.3 # 30% of the generated receptors will have class 1
            0: 0.7 # 70% of the generated receptors will have class 0
    """
    RandomDatasetGenerator._check_receptor_dataset_generation_params(receptor_count, chain_1_length_probabilities,
                                                                     chain_2_length_probabilities, labels, path)

    alphabet = EnvironmentSettings.get_sequence_alphabet()
    PathBuilder.build(path)

    get_random_sequence = lambda proba, chain, id: ReceptorSequence("".join(random.choices(alphabet,
                                                                                           k=random.choices(list(proba.keys()),
                                                                                                            proba.values())[0])),
                                                                    metadata=SequenceMetadata(count=1,
                                                                                              v_subgroup=chain + "V1", v_gene=chain + "V1-1",
                                                                                              v_allele=chain + "V1-1*01",
                                                                                              j_subgroup=chain + "J1", j_gene=chain + "J1-1",
                                                                                              j_allele=chain + "J1-1*01",
                                                                                              chain=chain, cell_id=id))

    receptors = [TCABReceptor(alpha=get_random_sequence(chain_1_length_probabilities, "TRA", i),
                              beta=get_random_sequence(chain_2_length_probabilities, "TRB", i),
                              metadata={**{label: random.choices(list(label_dict.keys()), label_dict.values(), k=1)[0]
                                           for label, label_dict in labels.items()},
                                        **{"subject": f"subj_{i + 1}"}})
                 for i in range(receptor_count)]

    filename = path / "batch01.npy"

    receptor_matrix = np.core.records.fromrecords([receptor.get_record() for receptor in receptors], names=TCABReceptor.get_record_names())
    np.save(str(filename), receptor_matrix, allow_pickle=False)

    return ReceptorDataset(labels={label: list(label_dict.keys()) for label, label_dict in labels.items()},
                           filenames=[filename], file_size=receptor_count,
                           element_class_name=type(receptors[0]).__name__ if len(receptors) > 0 else None)
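# Call sketch for the paired-chain variant, again with hypothetical probabilities and labels:
receptor_dataset = RandomDatasetGenerator.generate_receptor_dataset(receptor_count=100,
                                                                    chain_1_length_probabilities={14: 0.8, 15: 0.2},
                                                                    chain_2_length_probabilities={14: 0.8, 15: 0.2},
                                                                    labels={"epitope1": {True: 0.5, False: 0.5}},
                                                                    path=Path("random_receptor_dataset/"))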
def get_file_path(cache_type=None):
    file_path = EnvironmentSettings.get_cache_path(cache_type) / "files"
    PathBuilder.build(file_path)
    return file_path
def test_repertoire_flattened(self):
    path = EnvironmentSettings.root_path / "test/tmp/onehot_recep_flat/"
    PathBuilder.build(path)

    dataset, lc = self._construct_test_repertoiredataset(path, positional=False)

    encoder = OneHotEncoder.build_object(dataset, **{"use_positional_info": False, "distance_to_seq_middle": None, "flatten": True})

    encoded_data = encoder.encode(dataset, EncoderParams(result_path=path, label_config=lc, pool_size=1, learn_model=True, model={},
                                                         filename="dataset.pkl"))

    self.assertTrue(isinstance(encoded_data, RepertoireDataset))

    onehot_a = [1.0] + [0.0] * 19
    onehot_t = [0.0] * 16 + [1.0] + [0] * 3
    onehot_empty = [0] * 20

    self.assertListEqual(list(encoded_data.encoded_data.examples[0]),
                         onehot_a + onehot_a + onehot_a + onehot_a + onehot_a + onehot_t + onehot_a + onehot_empty + onehot_a + onehot_t + onehot_a + onehot_empty)
    self.assertListEqual(list(encoded_data.encoded_data.examples[1]),
                         onehot_a + onehot_t + onehot_a + onehot_empty + onehot_t + onehot_a + onehot_a + onehot_empty + onehot_empty + onehot_empty + onehot_empty + onehot_empty)

    self.assertListEqual(list(encoded_data.encoded_data.feature_names),
                         [f"{seq}_{pos}_{char}" for seq in range(3) for pos in range(4) for char in EnvironmentSettings.get_sequence_alphabet()])

    shutil.rmtree(path)