Exemplo n.º 1
0
    def __init__(self, model: str, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Initialize annotation extractor, downloading any checkpoint file that was not passed.

        :param model: model name forwarded to ``get_model_file`` when a necessary file is missing
        :param device: device specification, resolved through ``get_device``
        :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
        :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
        """

        self._options = kwargs
        self._device = get_device(device)

        # Create un-trained (raw) model
        self._subcellular_location_model = LightAttention(output_dim=10).to(self._device)
        self._membrane_model = LightAttention(output_dim=2).to(self._device)

        # Download files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model=model, file=file)

        # Read the paths only AFTER the download step: reading them earlier left them as None
        # whenever the caller relied on the automatic download.
        # NOTE(review): assumes self.necessary_files lists both checkpoint keys - confirm on the class.
        self._subcellular_location_checkpoint_file = self._options.get('subcellular_location_checkpoint_file')
        self._membrane_checkpoint_file = self._options.get('membrane_checkpoint_file')

        # load pre-trained weights for annotation machines
        subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
        membrane_state = torch.load(self._membrane_checkpoint_file, map_location=self._device)

        # load pre-trained weights into raw model
        self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
        self._membrane_model.load_state_dict(membrane_state['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._subcellular_location_model.eval()
        self._membrane_model.eval()
Exemplo n.º 2
0
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Store the options, resolve the device and fetch every missing model file/directory.

        Entries of ``self.necessary_files`` and ``self.necessary_directories`` that are absent
        (or falsy) in ``kwargs`` are obtained via ``get_model_file`` /
        ``get_model_directories_from_zip`` using ``self.name``. A warning is logged when only
        some of the necessary entries were passed by the caller.

        :param device: device specification, resolved through ``get_device``
        :param kwargs: options, including optional paths for the necessary files/directories
        """
        self._options = kwargs
        self._device = get_device(device)

        # Special case because SeqVec can currently be used with either a model directory or two files
        if self.__class__.__name__ == "SeqVecEmbedder":
            # No need to download weights_file/options_file if model_directory is given
            if "model_directory" in self._options:
                return

        # Number of entries the caller did NOT provide and that therefore had to be fetched
        downloaded = 0
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model=self.name, file=file)
                downloaded += 1

        for directory in self.necessary_directories:
            if not self._options.get(directory):
                self._options[directory] = get_model_directories_from_zip(
                    model=self.name, directory=directory
                )
                downloaded += 1

        total_necessary = len(self.necessary_files) + len(self.necessary_directories)
        if 0 < downloaded < total_necessary:
            # Report what the caller actually passed; the fetched count is its complement.
            provided = total_necessary - downloaded
            logger.warning(
                "You should pass either all necessary files or directories, or none, "
                f"while you provide {provided} of {total_necessary}"
            )
Exemplo n.º 3
0
    def __init__(self,
                 model_type: str,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Build the conservation predictor and load its pre-trained weights.

        :param model_type: stored as ``self._model_type``; not otherwise used in this initializer
        :param device: device specification, resolved through ``get_device``
        :param model_file: path of conservation inference model checkpoint file (only CNN-architecture from paper);
            fetched through ``get_model_file`` when not passed
        """

        self._options = kwargs
        self._model_type = model_type
        self._device = get_device(device)

        # Blank network, moved to the target device before the weights are read
        self._conservation_model = ConservationCNN().to(self._device)

        # Fall back to the downloadable checkpoint when no path was passed
        checkpoint_path = self._options.get("model_file")
        if not checkpoint_path:
            checkpoint_path = get_model_file(model=f"prott5cons", file="model_file")
            self._options["model_file"] = checkpoint_path
        self.model_file = checkpoint_path

        # Read the checkpoint and copy its weights into the network
        checkpoint = torch.load(self.model_file, map_location=self._device)
        self._conservation_model.load_state_dict(checkpoint["state_dict"])

        # Evaluation mode matters for batchnorm and dropout layers
        self._conservation_model.eval()
Exemplo n.º 4
0
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Load the five TMbed models and create the decoder.

        Checkpoint paths are read from ``kwargs`` under ``model_<idx>_file`` (idx 0-4); every
        entry of ``self.necessary_files`` that is missing is fetched via ``get_model_file``.

        :param device: device specification, resolved through ``get_device``
        :param kwargs: options, including optional checkpoint file paths
        """

        self._options = kwargs
        self._device = get_device(device)

        # Download the checkpoint files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model='tmbed', file=file)

        self._models = []

        for model_idx in range(5):
            # Create blank model
            model = TmbedModel()
            # Get model file
            model_file = self._options[f'model_{model_idx}_file']
            # Load pre-trained weights; map_location remaps the stored tensors onto the
            # selected device so a checkpoint saved on GPU also loads on a CPU-only host
            state = torch.load(model_file, map_location=self._device)
            model.load_state_dict(state['model'])
            # Finalize model
            model = model.eval().to(self._device)
            # Add to model list
            self._models.append(model)

        self._decoder = Decoder()
Exemplo n.º 5
0
def pb_tucker(file_manager: FileManagerInterface,
              result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Project every reduced embedding with PBTucker and store the results in an HDF5 file.

    :param file_manager: used to create the output file for this stage
    :param result_kwargs: stage options; reads ``device``, ``model_file`` (optional, fetched via
        ``get_model_file`` when missing or empty), ``reduced_embeddings_file``, ``prefix`` and
        ``stage_name``
    :return: ``result_kwargs`` with ``projected_reduced_embeddings_file`` set to the output path
    """
    device = get_device(result_kwargs.get("device"))

    # Treat a missing key and an empty/None value alike: both mean "use the default weights"
    model_file = result_kwargs.get("model_file") or get_model_file(
        "pb_tucker", "model_file")
    # Distinct local name so the instance does not shadow this function
    projector = PBTucker(model_file, device)

    reduced_embeddings_file_path = result_kwargs["reduced_embeddings_file"]
    # The output is written with h5py below, so it gets an HDF5 extension rather than ".csv"
    projected_reduced_embeddings_file_path = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "projected_reduced_embeddings_file",
        extension=".h5",
    )
    result_kwargs[
        "projected_reduced_embeddings_file"] = projected_reduced_embeddings_file_path

    with h5py.File(reduced_embeddings_file_path, "r") as input_embeddings, \
            h5py.File(projected_reduced_embeddings_file_path, "w") as output_embeddings:
        for h5_id, reduced_embedding in input_embeddings.items():
            output_embeddings[h5_id] = projector.project_reduced_embedding(
                reduced_embedding)

    return result_kwargs
    def __init__(self,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Initialize annotation extractor. Must define non-positional arguments for paths of files.

        :param device: device specification, resolved through ``get_device``
        :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
        :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
        :raises MissingParameterError: if a file listed in ``self.necessary_files`` was not passed
        """

        self._options = kwargs

        # Check that file paths are passed before allocating anything, so a misconfigured call
        # fails fast; a key explicitly set to None/empty counts as missing as well.
        for file in self.necessary_files:
            if not self._options.get(file):
                raise MissingParameterError(
                    'Please provide subcellular_location_checkpoint_file and membrane_checkpoint_file paths as named '
                    'parameters to the constructor. Mind that these should match the embeddings used, '
                    'e.g.: prottrans_bert_bfd should use la_protbert weights')

        # Resolve the device once; it is reused for the models and as map_location below
        self._device = get_device(device)

        self._subcellular_location_checkpoint_file = self._options.get(
            'subcellular_location_checkpoint_file')
        self._membrane_checkpoint_file = self._options.get(
            'membrane_checkpoint_file')

        # Create un-trained (raw) model
        self._subcellular_location_model = LightAttention(output_dim=10).to(
            self._device)
        self._membrane_model = LightAttention(output_dim=2).to(self._device)

        # load pre-trained weights for annotation machines
        subcellular_state = torch.load(
            self._subcellular_location_checkpoint_file,
            map_location=self._device)
        membrane_state = torch.load(self._membrane_checkpoint_file,
                                    map_location=self._device)

        # load pre-trained weights into raw model
        self._subcellular_location_model.load_state_dict(
            subcellular_state['state_dict'])
        self._membrane_model.load_state_dict(membrane_state['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._subcellular_location_model.eval()
        self._membrane_model.eval()
    def __init__(self,
                 model_type: str,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Build the secondary structure and subcellular location predictors and load their weights.

        :param model_type: either ``seqvec_from_publication`` or ``bert_from_publication``;
            anything else raises ``NotImplementedError``
        :param device: device specification, resolved through ``get_device``
        :param secondary_structure_checkpoint_file: path of secondary structure inference model checkpoint file
        :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
        """

        self._options = kwargs
        self._model_type = model_type
        self._device = get_device(device)

        # Reject unknown model types up front, then pick the matching architecture
        supported_types = ("seqvec_from_publication", "bert_from_publication")
        if self._model_type not in supported_types:
            print("You first need to define your custom model architecture.")
            raise NotImplementedError

        if self._model_type == "seqvec_from_publication":
            location_network = SUBCELL_FNN()
        else:
            # ProtTrans models are used without batchNorm
            location_network = SUBCELL_FNN(use_batch_norm=False)
        self._subcellular_location_model = location_network.to(self._device)

        # Fetch every checkpoint file that was not passed explicitly
        remote_model_name = f"{self._model_type}_annotations_extractors"
        for file in self.necessary_files:
            if self._options.get(file):
                continue
            self._options[file] = get_model_file(model=remote_model_name,
                                                 file=file)

        options = self._options
        self._secondary_structure_checkpoint_file = options['secondary_structure_checkpoint_file']
        self._subcellular_location_checkpoint_file = options['subcellular_location_checkpoint_file']

        # Blank secondary structure network
        self._secondary_structure_model = SECSTRUCT_CNN().to(self._device)

        # Read both checkpoints onto the selected device
        location_checkpoint = torch.load(
            self._subcellular_location_checkpoint_file,
            map_location=self._device)
        structure_checkpoint = torch.load(
            self._secondary_structure_checkpoint_file,
            map_location=self._device)

        # Copy the stored weights into the blank networks
        self._subcellular_location_model.load_state_dict(location_checkpoint['state_dict'])
        self._secondary_structure_model.load_state_dict(structure_checkpoint['state_dict'])

        # Evaluation mode matters for batchnorm and dropout layers
        self._subcellular_location_model.eval()
        self._secondary_structure_model.eval()
Exemplo n.º 8
0
    def __init__(
        self,
        device: Union[None, str, torch.device] = None,
        model_directory: Optional[str] = None,
        half_precision_model: bool = False,
    ):
        """Load the tokenizer and the Bert model for masked LM.

        :param device: device specification, resolved through ``get_device``
        :param model_directory: directory with the pre-trained model; fetched via
            ``get_model_directories_from_zip`` when not given
        :param half_precision_model: if truthy, the model is converted with ``.half()``
        """
        self.device = get_device(device)
        self._half_precision_model = half_precision_model

        # Fall back to the packaged model directory when none was passed
        weights_directory = model_directory or get_model_directories_from_zip(
            model=ProtTransBertBFDEmbedder.name, directory="model_directory"
        )

        self.tokenizer = BertTokenizer.from_pretrained(
            weights_directory, do_lower_case=False
        )

        network = BertForMaskedLM.from_pretrained(weights_directory)
        if self._half_precision_model:
            # Half precision is a lot faster and needs half the memory
            network = network.half()
        self.model = network.eval().to(self.device)
Exemplo n.º 9
0
    def __init__(self,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Initialize annotation extractor. Must define non-positional arguments for paths of files.

        Five ``BindingResiduesCNN`` models are created, each loaded from its own checkpoint.

        :param device: device specification, resolved through ``get_device``
        :param model_1_file: path of the first bindEmbed21DL inference model checkpoint file;
            ``model_2_file`` to ``model_5_file`` likewise. Every entry of ``self.necessary_files``
            that is missing is fetched through ``get_model_file``.
        """

        self._options = kwargs
        self._device = get_device(device)

        # Create un-trained (raw) models
        self._binding_residue_model_1 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_2 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_3 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_4 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_5 = BindingResiduesCNN().to(self._device)

        # Download the checkpoint files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model="bindembed21dl",
                                                     file=file)

        self.model_file_1 = self._options['model_1_file']
        self.model_file_2 = self._options['model_2_file']
        self.model_file_3 = self._options['model_3_file']
        self.model_file_4 = self._options['model_4_file']
        self.model_file_5 = self._options['model_5_file']

        # Pair every model with ITS OWN checkpoint: loading model_file_1 into all five models
        # would silently reduce the ensemble to five copies of the same network.
        model_checkpoint_pairs = (
            (self._binding_residue_model_1, self.model_file_1),
            (self._binding_residue_model_2, self.model_file_2),
            (self._binding_residue_model_3, self.model_file_3),
            (self._binding_residue_model_4, self.model_file_4),
            (self._binding_residue_model_5, self.model_file_5),
        )
        for binding_residue_model, model_file in model_checkpoint_pairs:
            # load pre-trained weights for annotation machines
            binding_residue_state = torch.load(model_file,
                                               map_location=self._device)
            # load pre-trained weights into raw model
            binding_residue_model.load_state_dict(
                binding_residue_state['state_dict'])
            # ensure that model is in evaluation mode (important for batchnorm and dropout)
            binding_residue_model.eval()
Exemplo n.º 10
0
def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    Computes per-residue probabilities for every sequence of the remapped
    sequences file and writes them to a CSV file created for this stage.

    required:
     * protocol (must be a key of ``_PROTOCOLS``)
     * prefix
     * stage_name
     * remapped_sequences_file
     * mapping_file

    optional (see extract stage for details):
     * model_directory
     * device
     * half_precision
     * half_precision_model
     * temperature: temperature for softmax

    Returns a deep copy of ``kwargs`` extended with ``temperature`` (default 1)
    and ``residue_probabilities_file``.
    """
    check_required(
        kwargs,
        [
            "protocol",
            "prefix",
            "stage_name",
            "remapped_sequences_file",
            "mapping_file",
        ],
    )
    result_kwargs = deepcopy(kwargs)

    protocol = result_kwargs["protocol"]
    if protocol not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, but allowed are: {', '.join(_PROTOCOLS)}"
        )

    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))
    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[protocol]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    # The mapping file contains the corresponding ids in the same order
    fasta_records = SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    sequences = [str(record.seq) for record in fasta_records]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    probabilities_all = {}
    residue_count = int(mapping_file["sequence_length"].sum())
    with tqdm(total=residue_count) as progress_bar:
        id_sequence_rows = zip(
            mapping_file.index, mapping_file["original_id"], sequences
        )
        for sequence_id, _original_id, sequence in id_sequence_rows:
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )

            # Each entry also carries a "position" value, which is not part of the softmax
            for residue_entry in probabilities:
                softmax_total = sum(residue_entry.values()) - residue_entry["position"]
                assert math.isclose(
                    1, softmax_total, rel_tol=1e-6
                ), "softmax values should add up to 1"

            probabilities_all[sequence_id] = probabilities

    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )
    residue_probabilities.to_csv(probabilities_file, index=False)
    result_kwargs["residue_probabilities_file"] = probabilities_file
    return result_kwargs