def __init__(self, model: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
    :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
    """
    self._options = kwargs
    self._device = get_device(device)

    # Create un-trained (raw) model
    self._subcellular_location_model = LightAttention(output_dim=10).to(self._device)
    self._membrane_model = LightAttention(output_dim=2).to(self._device)

    # Download files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model=model, file=file)

    # Resolve checkpoint paths only after the download step, so auto-fetched files are used
    self._subcellular_location_checkpoint_file = self._options.get('subcellular_location_checkpoint_file')
    self._membrane_checkpoint_file = self._options.get('membrane_checkpoint_file')

    # load pre-trained weights for annotation machines
    subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
    membrane_state = torch.load(self._membrane_checkpoint_file, map_location=self._device)

    # load pre-trained weights into raw model
    self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
    self._membrane_model.load_state_dict(membrane_state['state_dict'])

    # ensure that model is in evaluation mode (important for batchnorm and dropout)
    self._subcellular_location_model.eval()
    self._membrane_model.eval()
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initializer accepts location of a pre-trained model and options
    """
    self._options = kwargs
    self._device = get_device(device)

    # Special case because SeqVec can currently be used with either a model directory or two files
    if self.__class__.__name__ == "SeqVecEmbedder":
        # No need to download weights_file/options_file if model_directory is given
        if "model_directory" in self._options:
            return

    files_loaded = 0
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model=self.name, file=file)
            files_loaded += 1

    for directory in self.necessary_directories:
        if not self._options.get(directory):
            self._options[directory] = get_model_directories_from_zip(
                model=self.name, directory=directory
            )
            files_loaded += 1

    total_necessary = len(self.necessary_files) + len(self.necessary_directories)
    if 0 < files_loaded < total_necessary:
        logger.warning(
            f"You should pass either all necessary files or directories, or none, "
            f"while you provide {files_loaded} of {total_necessary}"
        )
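# Usage sketch for the embedder initializer above, illustrating the SeqVec special case:
# passing model_directory skips the weights_file/options_file download entirely.
# The import path and the local directory are assumptions for illustration.
from bio_embeddings.embed import SeqVecEmbedder

embedder = SeqVecEmbedder(model_directory="/path/to/local/seqvec_model_dir")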
def __init__(self, model_type: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param model_file: path of conservation inference model checkpoint file (only CNN-architecture from paper)
    """
    self._options = kwargs
    self._model_type = model_type
    self._device = get_device(device)

    # Create un-trained (raw) model and ensure self._model_type is valid
    self._conservation_model = ConservationCNN().to(self._device)

    # Download the checkpoint file if needed
    if not self._options.get("model_file"):
        self._options["model_file"] = get_model_file(model="prott5cons", file="model_file")

    self.model_file = self._options["model_file"]

    # load pre-trained weights for annotation machine
    conservation_state = torch.load(self.model_file, map_location=self._device)

    # load pre-trained weights into raw model
    self._conservation_model.load_state_dict(conservation_state["state_dict"])

    # ensure that model is in evaluation mode (important for batchnorm and dropout)
    self._conservation_model.eval()
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    self._options = kwargs
    self._device = get_device(device)

    # Download the checkpoint files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model='tmbed', file=file)

    self._models = []

    for model_idx in range(5):
        # Create blank model
        model = TmbedModel()
        # Get model file
        model_file = self._options[f'model_{model_idx}_file']
        # Load pre-trained weights (map_location keeps CPU-only machines working)
        model.load_state_dict(torch.load(model_file, map_location=self._device)['model'])
        # Finalize model
        model = model.eval().to(self._device)
        # Add to model list
        self._models.append(model)

    self._decoder = Decoder()
def pb_tucker(file_manager: FileManagerInterface, result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    device = get_device(result_kwargs.get("device"))
    if "model_file" not in result_kwargs:
        model_file = get_model_file("pb_tucker", "model_file")
    else:
        model_file = result_kwargs["model_file"]
    pb_tucker = PBTucker(model_file, device)

    reduced_embeddings_file_path = result_kwargs["reduced_embeddings_file"]
    projected_reduced_embeddings_file_path = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "projected_reduced_embeddings_file",
        # The projections are written with h5py, so the extension must be .h5, not .csv
        extension=".h5",
    )
    result_kwargs["projected_reduced_embeddings_file"] = projected_reduced_embeddings_file_path

    with h5py.File(reduced_embeddings_file_path, "r") as input_embeddings, h5py.File(
        projected_reduced_embeddings_file_path, "w"
    ) as output_embeddings:
        for h5_id, reduced_embedding in input_embeddings.items():
            output_embeddings[h5_id] = pb_tucker.project_reduced_embedding(reduced_embedding)

    return result_kwargs
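# Hypothetical invocation of pb_tucker() above. Only the keys read by the function
# (device, model_file, reduced_embeddings_file, prefix, stage_name) come from the code;
# the concrete paths and names are placeholders, and the stage directory is assumed to
# already exist (the pipeline normally creates it before calling a project stage).
file_manager = get_file_manager()
result = pb_tucker(
    file_manager,
    {
        "prefix": "my_run",
        "stage_name": "project_tucker",
        "reduced_embeddings_file": "my_run/embed/reduced_embeddings_file.h5",
        "device": "cpu",
    },
)
print(result["projected_reduced_embeddings_file"])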
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
    :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
    """
    self._options = kwargs
    self._device = get_device(device)

    # Create un-trained (raw) model
    self._subcellular_location_model = LightAttention(output_dim=10).to(self._device)
    self._membrane_model = LightAttention(output_dim=2).to(self._device)

    self._subcellular_location_checkpoint_file = self._options.get('subcellular_location_checkpoint_file')
    self._membrane_checkpoint_file = self._options.get('membrane_checkpoint_file')

    # check that file paths are passed
    for file in self.necessary_files:
        if file not in self._options:
            raise MissingParameterError(
                'Please provide subcellular_location_checkpoint_file and membrane_checkpoint_file paths as named '
                'parameters to the constructor. Mind that these should match the embeddings used, '
                'e.g.: prottrans_bert_bfd should use la_protbert weights')

    # load pre-trained weights for annotation machines
    subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
    membrane_state = torch.load(self._membrane_checkpoint_file, map_location=self._device)

    # load pre-trained weights into raw model
    self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
    self._membrane_model.load_state_dict(membrane_state['state_dict'])

    # ensure that model is in evaluation mode (important for batchnorm and dropout)
    self._subcellular_location_model.eval()
    self._membrane_model.eval()
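# Usage sketch for the constructor above. The extractor class name and the checkpoint
# paths are placeholders; the two required named parameters and the la_protbert hint
# come from the MissingParameterError message in the code.
extractor = LightAttentionAnnotationExtractor(
    subcellular_location_checkpoint_file="checkpoints/la_protbert_subcellular_location.pt",
    membrane_checkpoint_file="checkpoints/la_protbert_membrane.pt",
    device="cuda:0",
)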
def __init__(self, model_type: str, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param secondary_structure_checkpoint_file: path of secondary structure inference model checkpoint file
    :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
    """
    self._options = kwargs
    self._model_type = model_type
    self._device = get_device(device)

    # Create un-trained (raw) model and ensure self._model_type is valid
    if self._model_type == "seqvec_from_publication":
        self._subcellular_location_model = SUBCELL_FNN().to(self._device)
    elif self._model_type == "bert_from_publication":
        # Drop batchNorm for ProtTrans models
        self._subcellular_location_model = SUBCELL_FNN(use_batch_norm=False).to(self._device)
    else:
        raise NotImplementedError("You first need to define your custom model architecture.")

    # Download the checkpoint files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(
                model=f"{self._model_type}_annotations_extractors", file=file)

    self._secondary_structure_checkpoint_file = self._options['secondary_structure_checkpoint_file']
    self._subcellular_location_checkpoint_file = self._options['subcellular_location_checkpoint_file']

    # Read in pre-trained model
    self._secondary_structure_model = SECSTRUCT_CNN().to(self._device)

    # load pre-trained weights for annotation machines
    subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
    secondary_structure_state = torch.load(self._secondary_structure_checkpoint_file, map_location=self._device)

    # load pre-trained weights into raw model
    self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
    self._secondary_structure_model.load_state_dict(secondary_structure_state['state_dict'])

    # ensure that model is in evaluation mode (important for batchnorm and dropout)
    self._subcellular_location_model.eval()
    self._secondary_structure_model.eval()
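# Usage sketch for the constructor above (the extractor class name is assumed). The two
# accepted model_type values are taken directly from the branches in the code; any other
# value raises NotImplementedError.
extractor = BasicAnnotationExtractor("bert_from_publication", device="cpu")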
def __init__(
    self,
    device: Union[None, str, torch.device] = None,
    model_directory: Optional[str] = None,
    half_precision_model: bool = False,
):
    """Loads the Bert Model for Masked LM"""
    self.device = get_device(device)
    self._half_precision_model = half_precision_model

    if not model_directory:
        model_directory = get_model_directories_from_zip(
            model=ProtTransBertBFDEmbedder.name, directory="model_directory")

    self.tokenizer = BertTokenizer.from_pretrained(model_directory, do_lower_case=False)
    self.model = BertForMaskedLM.from_pretrained(model_directory)
    # Compute in half precision, which is a lot faster and saves us half the memory
    if self._half_precision_model:
        self.model = self.model.half()
    self.model = self.model.eval().to(self.device)
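# Usage sketch, assuming this is the constructor that run() below instantiates as
# ProtTransBertBFDMutagenesis; the argument order mirrors that call site and the
# device string is a placeholder.
model = ProtTransBertBFDMutagenesis(
    "cuda:0",  # device
    None,      # model_directory: None triggers the zip download above
    True,      # half_precision_model
)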
def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    """
    Initialize annotation extractor. Must define non-positional arguments for paths of files.

    :param model_file: path of bindEmbed21DL inference model checkpoint file
    """
    self._options = kwargs
    self._device = get_device(device)

    # Create un-trained (raw) models
    self._binding_residue_model_1 = BindingResiduesCNN().to(self._device)
    self._binding_residue_model_2 = BindingResiduesCNN().to(self._device)
    self._binding_residue_model_3 = BindingResiduesCNN().to(self._device)
    self._binding_residue_model_4 = BindingResiduesCNN().to(self._device)
    self._binding_residue_model_5 = BindingResiduesCNN().to(self._device)

    # Download the checkpoint files if needed
    for file in self.necessary_files:
        if not self._options.get(file):
            self._options[file] = get_model_file(model="bindembed21dl", file=file)

    self.model_file_1 = self._options['model_1_file']
    self.model_file_2 = self._options['model_2_file']
    self.model_file_3 = self._options['model_3_file']
    self.model_file_4 = self._options['model_4_file']
    self.model_file_5 = self._options['model_5_file']

    # load pre-trained weights for annotation machines
    # (each ensemble member loads its own checkpoint, not model_file_1 five times)
    binding_residue_state_1 = torch.load(self.model_file_1, map_location=self._device)
    binding_residue_state_2 = torch.load(self.model_file_2, map_location=self._device)
    binding_residue_state_3 = torch.load(self.model_file_3, map_location=self._device)
    binding_residue_state_4 = torch.load(self.model_file_4, map_location=self._device)
    binding_residue_state_5 = torch.load(self.model_file_5, map_location=self._device)

    # load pre-trained weights into raw models
    self._binding_residue_model_1.load_state_dict(binding_residue_state_1['state_dict'])
    self._binding_residue_model_2.load_state_dict(binding_residue_state_2['state_dict'])
    self._binding_residue_model_3.load_state_dict(binding_residue_state_3['state_dict'])
    self._binding_residue_model_4.load_state_dict(binding_residue_state_4['state_dict'])
    self._binding_residue_model_5.load_state_dict(binding_residue_state_5['state_dict'])

    # ensure that models are in evaluation mode (important for batchnorm and dropout)
    self._binding_residue_model_1.eval()
    self._binding_residue_model_2.eval()
    self._binding_residue_model_3.eval()
    self._binding_residue_model_4.eval()
    self._binding_residue_model_5.eval()
def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    optional (see extract stage for details):

    * model_directory
    * device
    * half_precision
    * half_precision_model
    * temperature: temperature for softmax
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)
    result_kwargs = deepcopy(kwargs)
    if result_kwargs["protocol"] not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, but allowed are: {', '.join(_PROTOCOLS)}"
        )
    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))

    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[result_kwargs["protocol"]]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    # The mapping file contains the corresponding ids in the same order
    sequences = [
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    ]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    probabilities_all = dict()
    with tqdm(total=int(mapping_file["sequence_length"].sum())) as progress_bar:
        for sequence_id, original_id, sequence in zip(
            mapping_file.index, mapping_file["original_id"], sequences
        ):
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )

            for p in probabilities:
                assert math.isclose(
                    1, (sum(p.values()) - p["position"]), rel_tol=1e-6
                ), "softmax values should add up to 1"

            probabilities_all[sequence_id] = probabilities

    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )

    residue_probabilities.to_csv(probabilities_file, index=False)
    result_kwargs["residue_probabilities_file"] = probabilities_file
    return result_kwargs
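# Hypothetical invocation of run() above; only the required kwargs checked by
# check_required() and the optional device/temperature keys come from the code.
# The protocol name and all paths are placeholders; the protocol must match a key of _PROTOCOLS.
result = run(
    protocol="protbert_bfd_mutagenesis",
    prefix="my_run",
    stage_name="mutagenesis",
    remapped_sequences_file="my_run/remapped_sequences_file.fasta",
    mapping_file="my_run/mapping_file.csv",
    device="cuda:0",
    temperature=1,
)
print(result["residue_probabilities_file"])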