def get_model(name: str, device: Union[None, str, torch.device]) -> Any:
    """Instantiate the model registered under ``name`` on ``device``.

    The checks run in a fixed order: the two publication annotation
    extractors, then ``"esm1v"`` (built with ``ensemble_id=1``), then any other
    key of ``name_to_embedder``, then ``"pb_tucker"`` and ``"deepblast"``.

    :param name: identifier of the model to build
    :param device: device handed to the model constructor (or to ``.to()`` for
        the DeepBLAST aligner)
    :return: the constructed model object
    :raises ValueError: if ``name`` matches none of the known models
    """
    annotation_extractor_names = ["bert_from_publication", "seqvec_from_publication"]
    if name in annotation_extractor_names:
        return BasicAnnotationExtractor(name, device)

    # esm1v must be handled before the generic embedder lookup because its
    # constructor additionally needs an ensemble id.
    if name == "esm1v":
        embedder_class = name_to_embedder[name]
        return embedder_class(ensemble_id=1, device=device)

    if name in name_to_embedder:
        embedder_class = name_to_embedder[name]
        return embedder_class(device=device)

    if name == "pb_tucker":
        return PBTucker(get_model_file("pb_tucker", "model_file"), device)

    if name == "deepblast":
        checkpoint = get_model_file("deepblast", "model_file")
        aligner = LightningAligner.load_from_checkpoint(checkpoint)
        return aligner.to(device)

    raise ValueError(f"Unknown name {name}")
def __init__(self, model_type: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Set up the conservation annotation extractor.

    File paths must be passed as keyword arguments.

    :param model_type: stored on the instance; not otherwise used in this method
    :param device: device specification resolved through ``get_device``
    :param model_file: path of conservation inference model checkpoint file
        (only CNN-architecture from paper); fetched via ``get_model_file`` when
        missing or empty
    """
    self._options = kwargs
    self._model_type = model_type
    self._device = get_device(device)

    # Un-trained (raw) network, placed on the target device
    network = ConservationCNN().to(self._device)
    self._conservation_model = network

    # Resolve the checkpoint path, downloading it if the caller gave none
    checkpoint_path = self._options.get("model_file")
    if not checkpoint_path:
        checkpoint_path = get_model_file(model="prott5cons", file="model_file")
        self._options["model_file"] = checkpoint_path
    self.model_file = checkpoint_path

    # Load the pre-trained weights into the raw network
    checkpoint = torch.load(checkpoint_path, map_location=self._device)
    network.load_state_dict(checkpoint["state_dict"])

    # Evaluation mode matters for batchnorm and dropout layers
    network.eval()
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        remapped_sequences_file: Required; checked for presence here
        prefix: Output prefix for all generated files
        protocol: Which embedder to use (must be a key of PROTOCOLS)
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        use_cpu*: Must not be set when the protocol resolves to UniRepEmbedder
        half_precision*: Forwarded to embed_and_write_batched (default False)

    Returns
    -------
    Whatever embed_and_write_batched returns (the results of the stage)
    """
    required_keys = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_keys)

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        message = "Invalid protocol selection: {}. Valid protocols are: {}".format(
            protocol, ", ".join(PROTOCOLS.keys())
        )
        raise InvalidParameterError(message)

    embedder_class = PROTOCOLS[protocol]

    cpu_option_given = kwargs.get("use_cpu") is not None
    if embedder_class == UniRepEmbedder and cpu_option_given:
        raise InvalidParameterError("UniRep does not support configuring `use_cpu`")

    result_kwargs = deepcopy(kwargs)

    # Fetch every model file the caller did not supply
    # noinspection PyProtectedMember
    for file_key in embedder_class._necessary_files:
        if result_kwargs.get(file_key):
            continue
        result_kwargs[file_key] = get_model_file(model=embedder_class.name, file=file_key)

    # Same for model directories
    # noinspection PyProtectedMember
    for directory_key in embedder_class._necessary_directories:
        if result_kwargs.get(directory_key):
            continue
        result_kwargs[directory_key] = get_model_directories_from_zip(
            model=embedder_class.name, directory=directory_key
        )

    # The default is looked up unconditionally, exactly as a setdefault call would
    default_max_amino_acids = DEFAULT_MAX_AMINO_ACIDS[protocol]
    result_kwargs.setdefault("max_amino_acids", default_max_amino_acids)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(embedder, file_manager, result_kwargs, half_precision)
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize the TMbed annotation extractor with its five-model ensemble.

    :param device: device specification resolved through ``get_device``
    :param kwargs: options; every key listed in ``self.necessary_files`` that is
        missing or empty is filled in via ``get_model_file``. The keys
        ``model_0_file`` ... ``model_4_file`` are read as checkpoint paths.
    """
    self._options = kwargs
    self._device = get_device(device)

    # Download the checkpoint files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model='tmbed', file=file)

    self._models = []

    for model_idx in range(5):
        # Create blank model
        model = TmbedModel()

        # Get model file
        model_file = self._options[f'model_{model_idx}_file']

        # Load pre-trained weights. The blank model has not been moved to the
        # target device yet, so map the checkpoint to the CPU; without
        # map_location a checkpoint saved from a GPU cannot be deserialized
        # on a machine without CUDA.
        checkpoint = torch.load(model_file, map_location="cpu")
        model.load_state_dict(checkpoint['model'])

        # Finalize model: evaluation mode, then move to the target device
        model = model.eval().to(self._device)

        # Add to model list
        self._models.append(model)

    self._decoder = Decoder()
def pb_tucker(file_manager: FileManagerInterface, result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Project every reduced embedding with PBTucker and store the results.

    Reads each dataset of the HDF5 file at ``result_kwargs["reduced_embeddings_file"]``,
    passes it through ``PBTucker.project_reduced_embedding`` and writes the
    result under the same id into a newly created output file.

    :param file_manager: used to create the output file
    :param result_kwargs: stage options; ``device``, ``model_file``, ``prefix``
        and ``stage_name`` are optional, ``reduced_embeddings_file`` is required
    :return: ``result_kwargs`` with ``projected_reduced_embeddings_file`` added
    """
    device = get_device(result_kwargs.get("device"))

    # Use the caller's checkpoint when the key is present, else fetch the default one
    checkpoint = (
        result_kwargs["model_file"]
        if "model_file" in result_kwargs
        else get_model_file("pb_tucker", "model_file")
    )
    projector = PBTucker(checkpoint, device)

    source_path = result_kwargs["reduced_embeddings_file"]
    destination_path = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "projected_reduced_embeddings_file",
        extension=".csv",
    )
    result_kwargs["projected_reduced_embeddings_file"] = destination_path

    with h5py.File(source_path, "r") as source, h5py.File(destination_path, "w") as destination:
        for h5_id, reduced_embedding in source.items():
            projected = projector.project_reduced_embedding(reduced_embedding)
            destination[h5_id] = projected

    return result_kwargs
def __init__(self, model: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param model: model name passed to ``get_model_file`` when a checkpoint has to be downloaded
    :param device: device specification resolved through ``get_device``
    :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
    :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
    """
    self._options = kwargs
    self._device = get_device(device)

    # Create un-trained (raw) models
    self._subcellular_location_model = LightAttention(output_dim=10).to(self._device)
    self._membrane_model = LightAttention(output_dim=2).to(self._device)

    # Download files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model=model, file=file)

    # Read the checkpoint paths only AFTER the download step: reading them
    # earlier leaves them as None whenever the caller relied on the automatic
    # download, and torch.load then fails.
    # NOTE(review): assumes both keys are listed in self.necessary_files - confirm.
    self._subcellular_location_checkpoint_file = self._options.get('subcellular_location_checkpoint_file')
    self._membrane_checkpoint_file = self._options.get('membrane_checkpoint_file')

    # load pre-trained weights for annotation machines
    subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
    membrane_state = torch.load(self._membrane_checkpoint_file, map_location=self._device)

    # load pre-trained weights into raw model
    self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
    self._membrane_model.load_state_dict(membrane_state['state_dict'])

    # ensure that model is in evaluation mode (important for batchnorm and dropout)
    self._subcellular_location_model.eval()
    self._membrane_model.eval()
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initializer accepts location of a pre-trained model and options.

    Every entry of ``self.necessary_files`` / ``self.necessary_directories``
    that is missing or empty in ``kwargs`` is fetched automatically. A warning
    is logged when the caller supplied only some of them.

    :param device: device specification resolved through ``get_device``
    :param kwargs: options, stored as ``self._options`` and filled in place
    """
    self._options = kwargs
    self._device = get_device(device)

    # Special case because SeqVec can currently be used with either a model directory or two files
    if self.__class__.__name__ == "SeqVecEmbedder":
        # No need to download weights_file/options_file if model_directory is given
        if "model_directory" in self._options:
            return

    # Counts the items that had to be downloaded, i.e. those NOT supplied by the caller
    files_loaded = 0
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model=self.name, file=file)
            files_loaded += 1

    for directory in self.necessary_directories:
        if not self._options.get(directory):
            self._options[directory] = get_model_directories_from_zip(
                model=self.name, directory=directory
            )
            files_loaded += 1

    total_necessary = len(self.necessary_files) + len(self.necessary_directories)

    # Some, but not all, items were downloaded: the caller passed a partial set
    if 0 < files_loaded < total_necessary:
        # Report what the caller actually provided, not what was downloaded
        files_provided = total_necessary - files_loaded
        logger.warning(
            "You should pass either all necessary files or directories, or none, "
            f"while you provide {files_provided} of {total_necessary}"
        )
def __init__(self, model_type: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Set up the secondary structure and subcellular location predictors.

    File paths must be passed as keyword arguments; missing ones are fetched
    via ``get_model_file``.

    :param model_type: "seqvec_from_publication" or "bert_from_publication";
        anything else raises NotImplementedError
    :param device: device specification resolved through ``get_device``
    :param secondary_structure_checkpoint_file: path of secondary structure inference model checkpoint file
    :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
    """
    self._options = kwargs
    self._model_type = model_type
    self._device = get_device(device)

    # Pick the raw subcellular location architecture; this also validates model_type
    if model_type == "seqvec_from_publication":
        subcellular_network = SUBCELL_FNN()
    elif model_type == "bert_from_publication":
        # Drop batchNorm for ProtTrans models
        subcellular_network = SUBCELL_FNN(use_batch_norm=False)
    else:
        print("You first need to define your custom model architecture.")
        raise NotImplementedError
    self._subcellular_location_model = subcellular_network.to(self._device)

    # Fetch whichever checkpoint files the caller did not provide
    download_source = f"{self._model_type}_annotations_extractors"
    for file_key in self.necessary_files:
        if self._options.get(file_key):
            continue
        self._options[file_key] = get_model_file(model=download_source, file=file_key)

    self._secondary_structure_checkpoint_file = self._options['secondary_structure_checkpoint_file']
    self._subcellular_location_checkpoint_file = self._options['subcellular_location_checkpoint_file']

    # Raw secondary structure network
    self._secondary_structure_model = SECSTRUCT_CNN().to(self._device)

    # Read both checkpoints, mapped onto the target device
    subcellular_checkpoint = torch.load(
        self._subcellular_location_checkpoint_file, map_location=self._device
    )
    secondary_structure_checkpoint = torch.load(
        self._secondary_structure_checkpoint_file, map_location=self._device
    )

    # Copy the pre-trained weights into the raw networks
    self._subcellular_location_model.load_state_dict(subcellular_checkpoint['state_dict'])
    self._secondary_structure_model.load_state_dict(secondary_structure_checkpoint['state_dict'])

    # Evaluation mode matters for batchnorm and dropout layers
    self._subcellular_location_model.eval()
    self._secondary_structure_model.eval()
def tmbed(**kwargs) -> Dict[str, Any]:
    '''
    Protocol extracts membrane residues from "embeddings_file".
    Embeddings must have been generated with ProtT5-XL-U50.

    Requires "embeddings_file" and "remapped_sequences_file". Writes one FASTA
    file whose records carry the predicted membrane residue string instead of
    the amino acid sequence, and returns the options with
    "membrane_residues_predictions_file" added.
    '''
    check_required(kwargs, ['embeddings_file', 'remapped_sequences_file'])

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Fetch every checkpoint the caller did not supply
    for file_key in TmbedAnnotationExtractor.necessary_files:
        if result_kwargs.get(file_key):
            continue
        result_kwargs[file_key] = get_model_file(model='tmbed', file=file_key)

    extractor = TmbedAnnotationExtractor(**result_kwargs)

    # Try to create final file (if this fails, now is better than later)
    output_path = file_manager.create_file(
        result_kwargs.get('prefix'),
        result_kwargs.get('stage_name'),
        'membrane_residues_predictions_file',
        extension='.fasta',
    )
    result_kwargs['membrane_residues_predictions_file'] = output_path

    predicted_records = []

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            # Prepend a batch dimension (until batch processing is supported)
            single_embedding = np.array(embedding_file[protein_sequence.id])
            batch = single_embedding[np.newaxis]

            # Sequence lengths (only a single sequence for now)
            lengths = [len(protein_sequence.seq)]

            annotations = extractor.get_membrane_residues(batch, lengths)

            # Only a single item for now; batch mode would have to deepcopy
            # a different protein sequence per annotation
            for annotation in annotations:
                record = deepcopy(protein_sequence)
                residue_string = convert_list_of_enum_to_string(annotation.membrane_residues)
                record.seq = Seq(residue_string)
                predicted_records.append(record)

    # Write file
    write_fasta_file(predicted_records, output_path)

    return result_kwargs
def __init__(self, ensemble_id: int, device: Union[None, str, torch.device] = None, **kwargs):
    """Select one of the five ensemble members and initialize the embedder.

    :param ensemble_id: number of the model (1-5); asserted before anything else
    :param device: forwarded to the parent initializer
    :param kwargs: forwarded to the parent initializer; pass ``model_file`` to
        override the weights file that would otherwise be fetched
    """
    assert ensemble_id in range(1, 6), "The model number must be in 1-5"
    self.ensemble_id = ensemble_id

    # EmbedderInterface assumes static model files, but one of five has to be
    # chosen dynamically here, so the file is resolved before delegating.
    weights_key = f"model_{ensemble_id}_file"
    if "model_file" not in kwargs:
        kwargs["model_file"] = get_model_file(model=self.name, file=weights_key)

    super().__init__(device, **kwargs)
def test_tucker(pytestconfig, device):
    """Check PBTucker projections against the stored reference embeddings.

    Each reference ProtTransBertBFD embedding is mean-reduced over axis 0,
    projected, and compared to the stored PBTucker embedding of the same name.
    """
    reference_dir = pytestconfig.rootpath.joinpath("test-data/reference-embeddings")

    bert_path = reference_dir.joinpath(ProtTransBertBFDEmbedder.name + ".npz")
    bert_embeddings = numpy.load(bert_path)

    tucker_path = reference_dir.joinpath(PBTucker.name + ".npz")
    expected_embeddings = numpy.load(tucker_path)

    projector = PBTucker(get_model_file("pb_tucker", "model_file"), device)

    for name, embedding in bert_embeddings.items():
        reduced = embedding.mean(axis=0)
        projected = projector.project_reduced_embedding(reduced)
        matches = numpy.allclose(
            expected_embeddings[name], projected, rtol=1.0e-3, atol=1.0e-5
        )
        assert matches, name
def prott5cons(model: str, **kwargs) -> Dict[str, Any]:
    """
    Protocol extracts conservation from "embeddings_file".
    Embeddings can only be generated with ProtT5-XL-U50.

    Requires "embeddings_file", "mapping_file" and "remapped_sequences_file".
    Writes one FASTA file whose records carry the predicted conservation string
    and returns the options with "conservation_predictions_file" added.

    :param model: "t5_xl_u50_conservation". Used to download files
    """
    check_required(kwargs, ['embeddings_file', 'mapping_file', 'remapped_sequences_file'])

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Fetch every checkpoint the caller did not supply
    for file_key in ProtT5consAnnotationExtractor.necessary_files:
        if result_kwargs.get(file_key):
            continue
        result_kwargs[file_key] = get_model_file(model=model, file=file_key)

    annotation_extractor = ProtT5consAnnotationExtractor(**result_kwargs)

    # The mapping file is read here but its content is not used further in this function
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    # Try to create final files (if this fails, now is better than later)
    output_path = file_manager.create_file(
        result_kwargs.get('prefix'),
        result_kwargs.get('stage_name'),
        'conservation_predictions_file',
        extension='.fasta',
    )
    result_kwargs['conservation_predictions_file'] = output_path

    predicted_records = []

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            embedding = np.array(embedding_file[protein_sequence.id])
            annotations = annotation_extractor.get_conservation(embedding)

            record = deepcopy(protein_sequence)
            conservation_string = convert_list_of_enum_to_string(annotations.conservation)
            record.seq = Seq(conservation_string)
            predicted_records.append(record)

    # Write files
    write_fasta_file(predicted_records, output_path)

    return result_kwargs
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        half_precision*: Forwarded to embed_and_write_batched (default False)

    Argument validation is delegated to prepare_kwargs.

    Returns
    -------
    Whatever embed_and_write_batched returns (the results of the stage)
    """
    embedder_class, result_kwargs = prepare_kwargs(**kwargs)

    # Fetch every model file the caller did not supply
    # noinspection PyProtectedMember
    for file_key in embedder_class.necessary_files:
        if result_kwargs.get(file_key):
            continue
        result_kwargs[file_key] = get_model_file(model=embedder_class.name, file=file_key)

    # Same for model directories
    # noinspection PyProtectedMember
    for directory_key in embedder_class.necessary_directories:
        if result_kwargs.get(directory_key):
            continue
        result_kwargs[directory_key] = get_model_directories_from_zip(
            model=embedder_class.name, directory=directory_key
        )

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(embedder, file_manager, result_kwargs, half_precision)
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    Builds five ``BindingResiduesCNN`` models and loads each one from its own
    checkpoint (``model_1_file`` ... ``model_5_file``).

    :param device: device specification resolved through ``get_device``
    :param model_1_file ... model_5_file: paths of the bindEmbed21DL inference
        model checkpoint files; missing ones are fetched via ``get_model_file``
    """
    self._options = kwargs
    self._device = get_device(device)

    # Download the checkpoint files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model="bindembed21dl", file=file)

    model_files = [self._options[f'model_{number}_file'] for number in range(1, 6)]
    (
        self.model_file_1,
        self.model_file_2,
        self.model_file_3,
        self.model_file_4,
        self.model_file_5,
    ) = model_files

    models = []
    for model_file in model_files:
        # Create un-trained (raw) model
        model = BindingResiduesCNN().to(self._device)

        # Load the pre-trained weights from THIS model's checkpoint. Loading
        # every state from the first checkpoint would make all five ensemble
        # members identical.
        state = torch.load(model_file, map_location=self._device)
        model.load_state_dict(state['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        model.eval()
        models.append(model)

    (
        self._binding_residue_model_1,
        self._binding_residue_model_2,
        self._binding_residue_model_3,
        self._binding_residue_model_4,
        self._binding_residue_model_5,
    ) = models
def deepblast(**kwargs) -> Dict[str, Any]:
    """Sequence-Sequence alignments with DeepBLAST

    DeepBLAST learned structural alignments from sequence

    https://github.com/flatironinstitute/deepblast

    https://www.biorxiv.org/content/10.1101/2020.11.03.365932v1

    Exactly one of ``transferred_annotations_file`` (output of k-nn; also needs
    ``reference_fasta_file`` for the targets) or ``pairings_file`` (csv with
    ``query`` and ``target`` columns; targets are looked up among the queries)
    must be given. ``mapping_file`` and ``remapped_sequences_file`` are read in
    both cases. One ``.a2m`` file is written per query.

    :return: the options with ``batch_size`` defaulted to 50 and
        ``alignment_files`` (query id -> file path) added
    :raises MissingParameterError: if neither input file option is given
    :raises InvalidParameterError: if both input file options are given
    :raises RuntimeError: if no CUDA device is selected or available
    :raises ValueError: if a pairing refers to an unknown query or target
    """
    # TODO: Fix that logic before merging
    if "transferred_annotations_file" not in kwargs and "pairings_file" not in kwargs:
        raise MissingParameterError(
            "You need to specify either 'transferred_annotations_file' or 'pairings_file' for DeepBLAST"
        )
    if "transferred_annotations_file" in kwargs and "pairings_file" in kwargs:
        raise InvalidParameterError(
            "You can't specify both 'transferred_annotations_file' and 'pairings_file' for DeepBLAST"
        )
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # This stays below 8GB, so it should be a good default
    batch_size = result_kwargs.setdefault("batch_size", 50)

    if "device" in result_kwargs:
        device = torch.device(result_kwargs["device"])
        if device.type != "cuda":
            raise RuntimeError(
                f"You can only run DeepBLAST on a CUDA-compatible GPU, not on {device.type}"
            )
    else:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "DeepBLAST requires a CUDA-compatible GPU, but none was found"
            )
        device = torch.device("cuda")

    # Translate the remapped ids of the sequence file back to the original ids
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    mapping = {
        str(remapped): original
        for remapped, original in mapping_file[["original_id"]].itertuples()
    }
    query_by_id = {
        mapping[entry.name]: str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    }

    # You can either provide a set of pairing or use the output of k-nn with a fasta file for the reference embeddings
    if "pairings_file" in result_kwargs:
        pairings_file = read_csv(result_kwargs["pairings_file"])
        pairings = list(pairings_file[["query", "target"]].itertuples(index=False))
        target_by_id = query_by_id
    else:
        transferred_annotations_file = read_csv(
            result_kwargs["transferred_annotations_file"]
        )
        pairings = [
            (row["original_id"], target)
            for _, row in transferred_annotations_file.iterrows()
            for target in row.filter(regex="k_nn_.*_identifier")
        ]
        target_by_id = {
            entry.name: str(entry.seq[:])
            for entry in read_fasta(result_kwargs["reference_fasta_file"])
        }

    # Validate before touching the file system, so a bad pairing does not leave
    # empty output files behind. Set comprehensions also work for empty
    # pairings, where indexing into zip(*pairings) raises IndexError.
    queries = {query for query, _ in pairings}
    targets_requested = {target for _, target in pairings}
    unknown_queries = queries - set(query_by_id.keys())
    if unknown_queries:
        raise ValueError(f"Unknown query sequences: {unknown_queries}")
    unknown_targets = targets_requested - set(target_by_id.keys())
    if unknown_targets:
        raise ValueError(f"Unknown target sequences: {unknown_targets}")

    # Create one output file per query
    result_kwargs["alignment_files"] = dict()
    for query in queries:
        filename = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            f"{slugify(query, lowercase=False)}_alignments",
            extension=".a2m",
        )
        result_kwargs["alignment_files"][query] = filename

    # Nothing to align: skip fetching the model and running the aligner
    if not pairings:
        return result_kwargs

    # Load the pretrained model
    if "model_file" not in result_kwargs:
        model_file = get_model_file("deepblast", "model_file")
    else:
        model_file = result_kwargs["model_file"]

    # Collect the alignments per query regardless of the order in which they
    # arrive. Grouping only consecutive items would open the same output file
    # (mode "w") several times for a query whose pairings are not adjacent,
    # overwriting its earlier alignments.
    # NOTE(review): each alignment is unpacked below as
    # (query, target, aligned query, aligned target) - confirm against deepblast_align.
    alignments_by_query = {}
    for alignment in deepblast_align(
        pairings, query_by_id, target_by_id, model_file, device, batch_size
    ):
        alignments_by_query.setdefault(alignment[0], []).append(alignment)

    for query, query_alignments in alignments_by_query.items():
        _, targets, queries_aligned, targets_aligned = zip(*query_alignments)
        padded_query, padded_targets = pairwise_alignments_to_msa(
            queries_aligned, targets_aligned
        )
        with open(result_kwargs["alignment_files"][query], "w") as fp:
            fp.write(f">{query}\n")
            fp.write(f"{padded_query}\n")
            for target, padded_target in zip(targets, padded_targets):
                fp.write(f">{target}\n")
                fp.write(f"{padded_target}\n")

    return result_kwargs
), ( lambda: ProtTransBertBFDEmbedder(), lambda: BasicAnnotationExtractor("bert_from_publication"), 0.5387755102040817, ), ( lambda: ProtTransT5XLU50Embedder(half_precision_model=True), lambda: BasicAnnotationExtractor("t5_xl_u50_from_publication"), 0.6285714285714286, ), ( lambda: ProtTransT5XLU50Embedder(half_precision_model=True), lambda: LightAttentionAnnotationExtractor( subcellular_location_checkpoint_file=get_model_file( "la_prott5", "subcellular_location_checkpoint_file" ), membrane_checkpoint_file=get_model_file( "la_prott5", "membrane_checkpoint_file" ), ), 0.6551020408163265, ), ], ) def test_basic_annotation_extractor( pytestconfig, get_embedder: Callable[[], EmbedderInterface], get_extractor: Callable[ [], Union[BasicAnnotationExtractor, LightAttentionAnnotationExtractor] ],
def bindembed21dl(**kwargs) -> Dict[str, Any]:
    """
    Protocol extracts binding residues from "embeddings_file".
    Results guaranteed only with ProtT5-XL-U50 embeddings.

    Requires "embeddings_file", "mapping_file" and "remapped_sequences_file".
    Writes three FASTA files (metal ion, nucleic acid and small molecule
    binding) whose records carry the predicted annotation strings.

    :return: the options with "metal_binding_predictions_file",
        "nucleic_acid_binding_predictions_file" and
        "small_molecule_binding_predictions_file" added.
        "binding_residue_predictions_file" is kept as a legacy alias of the
        small molecule file.
    """
    check_required(kwargs, ['embeddings_file', 'mapping_file', 'remapped_sequences_file'])

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    for file in BindEmbed21DLAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model="bindembed21dl", file=file)

    annotation_extractor = BindEmbed21DLAnnotationExtractor(**result_kwargs)

    # Try to create final files (if this fails, now is better than later)
    prefix = result_kwargs.get('prefix')
    stage_name = result_kwargs.get('stage_name')

    metal_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'metal_binding_predictions_file', extension='.fasta')
    result_kwargs['metal_binding_predictions_file'] = metal_binding_predictions_file_path

    nuc_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'nucleic_acid_binding_predictions_file', extension='.fasta')
    small_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'small_molecule_binding_predictions_file', extension='.fasta')

    # Each output gets its own key. Storing both paths under the same key
    # loses the nucleic acid path, because the second assignment overwrites it.
    result_kwargs['nucleic_acid_binding_predictions_file'] = nuc_binding_predictions_file_path
    result_kwargs['small_molecule_binding_predictions_file'] = small_binding_predictions_file_path
    # Legacy key: keeps the value it effectively had before (the small molecule file)
    result_kwargs['binding_residue_predictions_file'] = small_binding_predictions_file_path

    metal_sequences = list()
    nuc_sequences = list()
    small_sequences = list()

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            embedding = np.array(embedding_file[protein_sequence.id])

            annotations = annotation_extractor.get_binding_residues(embedding)

            # One copy of the record per annotation type, each with its own annotation string
            metal_sequence = deepcopy(protein_sequence)
            nuc_sequence = deepcopy(protein_sequence)
            small_sequence = deepcopy(protein_sequence)

            metal_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.metal_ion))
            nuc_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.nucleic_acids))
            small_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.small_molecules))

            metal_sequences.append(metal_sequence)
            nuc_sequences.append(nuc_sequence)
            small_sequences.append(small_sequence)

    # Write files
    write_fasta_file(metal_sequences, metal_binding_predictions_file_path)
    write_fasta_file(nuc_sequences, nuc_binding_predictions_file_path)
    write_fasta_file(small_sequences, small_binding_predictions_file_path)

    return result_kwargs
def _best_alignment_hit(alignment_results, query_id: str):
    """Return the best alignment hit of ``query_id`` as a dict, or None if it has none.

    The best hit has the lowest ``eval``; ties are broken by the highest
    ``fident``, and among remaining ties the first row is taken.
    """
    # Exact comparison: a regex/prefix match would let id "1" also select the
    # hits of "10", "11", ...
    hits = alignment_results[alignment_results['query'] == query_id]
    if hits.empty:
        return None

    lowest_eval_hits = hits[hits['eval'] == hits['eval'].min()]
    best_hits = lowest_eval_hits[lowest_eval_hits['fident'] == lowest_eval_hits['fident'].max()]
    if best_hits.empty:
        # No row compares equal to the maximum (e.g. every fident is NaN)
        return None

    return best_hits.iloc[0].to_dict()


def bindembed21(**kwargs) -> Dict[str, Any]:
    """
    Protocol extracts binding residues from "alignment_result_file" if possible, and from "embeddings_file", otherwise.

    For every protein the best alignment hit with ``eval`` below 1E-3 is passed
    to the homology-based (HBI) extractor. If the HBI result contains no
    binding annotation, or the protein has no such hit, the bindEmbed21DL
    extractor is run on the protein's embedding instead.

    :param kwargs: requires "alignment_results_file", "embeddings_file",
        "mapping_file" and "remapped_sequences_file"
    :return: the options with "metal_binding_predictions_file",
        "nucleic_acid_binding_predictions_file" and
        "small_molecule_binding_predictions_file" added.
        "binding_residue_predictions_file" is kept as a legacy alias of the
        small molecule file.
    """
    check_required(kwargs, ['alignment_results_file', 'embeddings_file', 'mapping_file', 'remapped_sequences_file'])

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    # for HBI
    for directory in BindEmbed21HBIAnnotationExtractor.necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(model="bindembed21hbi", directory=directory)
    # for DL
    for file in BindEmbed21DLAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model="bindembed21dl", file=file)

    hbi_extractor = BindEmbed21HBIAnnotationExtractor(**result_kwargs)
    dl_extractor = BindEmbed21DLAnnotationExtractor(**result_kwargs)

    # Try to create final files (if this fails, now is better than later)
    prefix = result_kwargs.get('prefix')
    stage_name = result_kwargs.get('stage_name')

    metal_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'metal_binding_predictions_file', extension='.fasta')
    result_kwargs['metal_binding_predictions_file'] = metal_binding_predictions_file_path

    nuc_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'nucleic_acid_binding_predictions_file', extension='.fasta')
    small_binding_predictions_file_path = file_manager.create_file(
        prefix, stage_name, 'small_molecule_binding_predictions_file', extension='.fasta')

    # Each output gets its own key. Storing both paths under the same key
    # loses the nucleic acid path, because the second assignment overwrites it.
    result_kwargs['nucleic_acid_binding_predictions_file'] = nuc_binding_predictions_file_path
    result_kwargs['small_molecule_binding_predictions_file'] = small_binding_predictions_file_path
    # Legacy key: keeps the value it effectively had before (the small molecule file)
    result_kwargs['binding_residue_predictions_file'] = small_binding_predictions_file_path

    metal_sequences = list()
    nuc_sequences = list()
    small_sequences = list()

    alignment_results = read_csv(result_kwargs['alignment_results_file'], sep='\t',
                                 dtype={'query': 'str', 'target': 'str'})
    # Only sufficiently significant hits are considered for HBI
    alignment_results = alignment_results[alignment_results['eval'] < 1E-3].copy()

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            metal_sequence = deepcopy(protein_sequence)
            nuc_sequence = deepcopy(protein_sequence)
            small_sequence = deepcopy(protein_sequence)

            # get HBI hit for this query (None if the protein has no usable hit)
            best_hit = _best_alignment_hit(alignment_results, str(protein_sequence.id))

            metal_inference = nuc_inference = small_inference = ''
            if best_hit is not None:
                hbi_annotations = hbi_extractor.get_binding_residues(best_hit)
                metal_inference = convert_list_of_enum_to_string(hbi_annotations.metal_ion)
                nuc_inference = convert_list_of_enum_to_string(hbi_annotations.nucleic_acids)
                small_inference = convert_list_of_enum_to_string(hbi_annotations.small_molecules)

            if 'M' in metal_inference or 'N' in nuc_inference or 'S' in small_inference:
                # some part of the sequence was predicted using HBI --> save output and don't run DL method
                metal_sequence.seq = Seq(metal_inference)
                nuc_sequence.seq = Seq(nuc_inference)
                small_sequence.seq = Seq(small_inference)
            else:
                # no inference containing binding annotations was made --> run bindEmbed21DL
                embedding = np.array(embedding_file[protein_sequence.id])
                annotations = dl_extractor.get_binding_residues(embedding)

                metal_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.metal_ion))
                nuc_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.nucleic_acids))
                small_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.small_molecules))

            metal_sequences.append(metal_sequence)
            nuc_sequences.append(nuc_sequence)
            small_sequences.append(small_sequence)

    # Write files
    write_fasta_file(metal_sequences, metal_binding_predictions_file_path)
    write_fasta_file(nuc_sequences, nuc_binding_predictions_file_path)
    write_fasta_file(small_sequences, small_binding_predictions_file_path)

    return result_kwargs