Example #1
def get_model(name: str, device: Union[None, str, torch.device]) -> Any:
    if name in ["bert_from_publication", "seqvec_from_publication"]:
        return BasicAnnotationExtractor(name, device)
    elif name == "esm1v":
        return name_to_embedder[name](ensemble_id=1, device=device)
    elif name in name_to_embedder:
        return name_to_embedder[name](device=device)
    elif name == "pb_tucker":
        return PBTucker(get_model_file("pb_tucker", "model_file"), device)
    elif name == "deepblast":
        model_file = get_model_file("deepblast", "model_file")
        return LightningAligner.load_from_checkpoint(model_file).to(device)
    else:
        raise ValueError(f"Unknown name {name}")
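A minimal usage sketch for the factory above; the device-selection logic and the CUDA check are assumptions, not part of the snippet:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU if no GPU
tucker = get_model("pb_tucker", device)  # returns a PBTucker instance per the branch above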
Example #2
    def __init__(self,
                 model_type: str,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Initialize annotation extractor. Paths of model files must be passed as non-positional (keyword) arguments.

        :param model_file: path of conservation inference model checkpoint file (only CNN-architecture from paper)
        """

        self._options = kwargs
        self._model_type = model_type
        self._device = get_device(device)

        # Create un-trained (raw) model
        self._conservation_model = ConservationCNN().to(self._device)

        # Download the checkpoint file if needed
        if not self._options.get("model_file"):
            self._options["model_file"] = get_model_file(model="prott5cons",
                                                         file="model_file")

        self.model_file = self._options["model_file"]

        # load pre-trained weights for annotation machines
        conservation_state = torch.load(self.model_file,
                                        map_location=self._device)

        # load pre-trained weights into raw model
        self._conservation_model.load_state_dict(
            conservation_state["state_dict"])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._conservation_model.eval()
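A minimal instantiation sketch; the class name ConservationAnnotationExtractor is hypothetical (the snippet only shows the __init__ body), and model_file is downloaded automatically when omitted:

# Hypothetical class name; substitute the real extractor class.
extractor = ConservationAnnotationExtractor(
    model_type="prott5cons",  # assumed value, stored on the instance
    device=None,              # None lets get_device pick CPU/GPU
)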
Example #3
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        remapped_sequences_file: Where the (remapped) sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: The mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name

    Returns
    -------
    Dictionary with results of stage
    """
    check_required(
        kwargs,
        [
            "protocol", "prefix", "stage_name", "remapped_sequences_file",
            "mapping_file"
        ],
    )

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())))

    embedder_class = PROTOCOLS[kwargs["protocol"]]

    if embedder_class == UniRepEmbedder and kwargs.get("use_cpu") is not None:
        raise InvalidParameterError(
            "UniRep does not support configuring `use_cpu`")

    result_kwargs = deepcopy(kwargs)

    # Download necessary files if needed
    # noinspection PyProtectedMember
    for file in embedder_class._necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=embedder_class.name,
                                                 file=file)

    # noinspection PyProtectedMember
    for directory in embedder_class._necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(
                model=embedder_class.name, directory=directory)

    result_kwargs.setdefault("max_amino_acids",
                             DEFAULT_MAX_AMINO_ACIDS[kwargs["protocol"]])

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
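A hedged invocation sketch; the protocol name and all paths are placeholders (valid protocol names are the keys of PROTOCOLS):

result = run(
    protocol="prottrans_bert_bfd",  # assumed to be a PROTOCOLS key
    prefix="run_output",            # placeholder output prefix
    stage_name="embeddings",
    remapped_sequences_file="remapped_sequences.fasta",  # placeholder
    mapping_file="mapping_file.csv",                     # placeholder
)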
Example #4
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):

        self._options = kwargs
        self._device = get_device(device)

        # Download the checkpoint files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model='tmbed', file=file)

        self._models = []

        for model_idx in range(5):
            # Create blank model
            model = TmbedModel()
            # Get model file
            model_file = self._options[f'model_{model_idx}_file']
            # Load pre-trained weights
            model.load_state_dict(torch.load(model_file)['model'])
            # Finalize model
            model = model.eval().to(self._device)
            # Add to model list
            self._models.append(model)

        self._decoder = Decoder()
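Instantiation sketch, assuming this __init__ belongs to TmbedAnnotationExtractor (Example #9 instantiates that class the same way); all five checkpoints (model_0_file .. model_4_file) are downloaded automatically when not passed:

extractor = TmbedAnnotationExtractor()  # no arguments needed; device auto-detected
assert len(extractor._models) == 5      # the five-member ensemble built above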
Example #5
def pb_tucker(file_manager: FileManagerInterface,
              result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    device = get_device(result_kwargs.get("device"))

    if "model_file" not in result_kwargs:
        model_file = get_model_file("pb_tucker", "model_file")
    else:
        model_file = result_kwargs["model_file"]
    pb_tucker = PBTucker(model_file, device)

    reduced_embeddings_file_path = result_kwargs["reduced_embeddings_file"]
    projected_reduced_embeddings_file_path = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "projected_reduced_embeddings_file",
        extension=".csv",
    )
    result_kwargs[
        "projected_reduced_embeddings_file"] = projected_reduced_embeddings_file_path

    with h5py.File(reduced_embeddings_file_path,
                   "r") as input_embeddings, h5py.File(
                       projected_reduced_embeddings_file_path,
                       "w") as output_embeddings:
        for h5_id, reduced_embedding in input_embeddings.items():
            output_embeddings[h5_id] = pb_tucker.project_reduced_embedding(
                reduced_embedding)

    return result_kwargs
Example #6
    def __init__(self, model: str, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Initialize annotation extractor. Paths of model files must be passed as non-positional (keyword) arguments.

        :param membrane_checkpoint_file: path of the membrane boundness inference model checkpoint file
        :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
        """

        self._options = kwargs
        self._device = get_device(device)

        # Create un-trained (raw) models
        self._subcellular_location_model = LightAttention(output_dim=10).to(self._device)
        self._membrane_model = LightAttention(output_dim=2).to(self._device)

        # Download files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model=model, file=file)

        # Resolve checkpoint paths only after the download loop has filled them in
        self._subcellular_location_checkpoint_file = self._options['subcellular_location_checkpoint_file']
        self._membrane_checkpoint_file = self._options['membrane_checkpoint_file']

        # load pre-trained weights for annotation machines
        subcellular_state = torch.load(self._subcellular_location_checkpoint_file, map_location=self._device)
        membrane_state = torch.load(self._membrane_checkpoint_file, map_location=self._device)

        # load pre-trained weights into raw model
        self._subcellular_location_model.load_state_dict(subcellular_state['state_dict'])
        self._membrane_model.load_state_dict(membrane_state['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._subcellular_location_model.eval()
        self._membrane_model.eval()
Example #7
    def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
        """
        Initializer accepts location of a pre-trained model and options
        """
        self._options = kwargs
        self._device = get_device(device)

        # Special case because SeqVec can currently be used with either a model directory or two files
        if self.__class__.__name__ == "SeqVecEmbedder":
            # No need to download weights_file/options_file if model_directory is given
            if "model_directory" in self._options:
                return

        files_loaded = 0
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model=self.name, file=file)
                files_loaded += 1

        for directory in self.necessary_directories:
            if not self._options.get(directory):
                self._options[directory] = get_model_directories_from_zip(
                    model=self.name, directory=directory
                )

                files_loaded += 1

        total_necessary = len(self.necessary_files) + len(self.necessary_directories)
        if 0 < files_loaded < total_necessary:
            logger.warning(
                f"You should pass either all necessary files or directories, or none, "
                f"while you provide {files_loaded} of {total_necessary}"
            )
Example #8
    def __init__(self,
                 model_type: str,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Initialize annotation extractor. Paths of model files must be passed as non-positional (keyword) arguments.

        :param secondary_structure_checkpoint_file: path of secondary structure inference model checkpoint file
        :param subcellular_location_checkpoint_file: path of the subcellular location inference model checkpoint file
        """

        self._options = kwargs
        self._model_type = model_type
        self._device = get_device(device)

        # Create un-trained (raw) model and ensure self._model_type is valid
        if self._model_type == "seqvec_from_publication":
            self._subcellular_location_model = SUBCELL_FNN().to(self._device)
        elif self._model_type == "bert_from_publication":  # Drop batchNorm for ProtTrans models
            self._subcellular_location_model = SUBCELL_FNN(
                use_batch_norm=False).to(self._device)
        else:
            raise NotImplementedError(
                "You first need to define your custom model architecture.")

        # Download the checkpoint files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(
                    model=f"{self._model_type}_annotations_extractors",
                    file=file)

        self._secondary_structure_checkpoint_file = self._options[
            'secondary_structure_checkpoint_file']
        self._subcellular_location_checkpoint_file = self._options[
            'subcellular_location_checkpoint_file']

        # Create un-trained (raw) secondary structure model
        self._secondary_structure_model = SECSTRUCT_CNN().to(self._device)

        # load pre-trained weights for annotation machines
        subcellular_state = torch.load(
            self._subcellular_location_checkpoint_file,
            map_location=self._device)
        secondary_structure_state = torch.load(
            self._secondary_structure_checkpoint_file,
            map_location=self._device)

        # load pre-trained weights into raw model
        self._subcellular_location_model.load_state_dict(
            subcellular_state['state_dict'])
        self._secondary_structure_model.load_state_dict(
            secondary_structure_state['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._subcellular_location_model.eval()
        self._secondary_structure_model.eval()
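Usage sketch, mirroring how Example #16 instantiates this extractor; the checkpoint files are downloaded automatically when not passed:

extractor = BasicAnnotationExtractor("bert_from_publication")  # or "seqvec_from_publication"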
Example #9
def tmbed(**kwargs) -> Dict[str, Any]:
    '''
    Protocol extracts membrane residues from "embeddings_file".
    Embeddings must have been generated with ProtT5-XL-U50.
    '''

    check_required(kwargs, ['embeddings_file', 'remapped_sequences_file'])

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    for file in TmbedAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model='tmbed', file=file)

    tmbed_extractor = TmbedAnnotationExtractor(**result_kwargs)

    # Try to create final file (if this fails, now is better than later)
    membrane_residues_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                       result_kwargs.get('stage_name'),
                                                                       'membrane_residues_predictions_file',
                                                                       extension='.fasta')

    result_kwargs['membrane_residues_predictions_file'] = membrane_residues_predictions_file_path

    tmbed_sequences = []

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            embedding = np.array(embedding_file[protein_sequence.id])

            # Add batch dimension (until we support batch processing)
            embedding = embedding[None, :]

            # Sequence lengths (only a single sequence for now)
            lengths = [len(protein_sequence.seq)]

            annotations = tmbed_extractor.get_membrane_residues(embedding, lengths)

            # Gratuitous loop (only a single item for now)
            # Needs to be changed for batch mode to deepcopy different protein sequences
            for annotation in annotations:
                tmbed_sequence = deepcopy(protein_sequence)
                tmbed_sequence.seq = Seq(convert_list_of_enum_to_string(annotation.membrane_residues))

                tmbed_sequences.append(tmbed_sequence)

    # Write file
    write_fasta_file(tmbed_sequences, membrane_residues_predictions_file_path)

    return result_kwargs
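A hedged call sketch; the paths are placeholders, and 'embeddings_file' must hold per-residue ProtT5-XL-U50 embeddings:

result = tmbed(
    embeddings_file="embeddings_file.h5",                # placeholder
    remapped_sequences_file="remapped_sequences.fasta",  # placeholder
    prefix="run_output",
    stage_name="tmbed",
)
print(result["membrane_residues_predictions_file"])  # path of the written fasta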
Example #10
    def __init__(self,
                 ensemble_id: int,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """You must pass the number of the model (1-5) as first parameter, though you can override the weights file with
        model_file"""
        assert ensemble_id in range(1, 6), "The model number must be in 1-5"
        self.ensemble_id = ensemble_id

        # EmbedderInterface assumes static model files, but we need to dynamically select one of the five
        if "model_file" not in kwargs:
            kwargs["model_file"] = get_model_file(
                model=self.name, file=f"model_{ensemble_id}_file")

        super().__init__(device, **kwargs)
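A sketch of building the full five-member ensemble; the class name ESM1vEmbedder is an assumption (Example #1 only reaches this class through name_to_embedder["esm1v"]):

embedders = [ESM1vEmbedder(ensemble_id=i) for i in range(1, 6)]  # hypothetical class name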
Example #11
def test_tucker(pytestconfig, device):
    bert_embeddings_file = pytestconfig.rootpath.joinpath(
        "test-data/reference-embeddings").joinpath(
            ProtTransBertBFDEmbedder.name + ".npz")
    bert_embeddings = numpy.load(bert_embeddings_file)
    tucker_embeddings_file = pytestconfig.rootpath.joinpath(
        "test-data/reference-embeddings").joinpath(PBTucker.name + ".npz")
    tucker_embeddings = numpy.load(tucker_embeddings_file)

    pb_tucker = PBTucker(get_model_file("pb_tucker", "model_file"), device)

    for name, embedding in bert_embeddings.items():
        reduced_embedding = embedding.mean(axis=0)
        tucker_embedding = pb_tucker.project_reduced_embedding(
            reduced_embedding)
        assert numpy.allclose(tucker_embeddings[name],
                              tucker_embedding,
                              rtol=1.0e-3,
                              atol=1.0e-5), name
Example #12
def prott5cons(model: str, **kwargs) -> Dict[str, Any]:
    """
    Protocol extracts conservation from "embeddings_file".
    Embeddings must have been generated with ProtT5-XL-U50.

    :param model: "t5_xl_u50_conservation". Used to download files
    """

    check_required(kwargs, ['embeddings_file', 'mapping_file', 'remapped_sequences_file'])
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    for file in ProtT5consAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=model, file=file)

    annotation_extractor = ProtT5consAnnotationExtractor(**result_kwargs)

    # mapping file will be needed for protein-wide annotations
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    # Try to create final files (if this fails, now is better than later)
    conservation_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                  result_kwargs.get('stage_name'),
                                                                  'conservation_predictions_file',
                                                                  extension='.fasta')
    result_kwargs['conservation_predictions_file'] = conservation_predictions_file_path
    cons_sequences = list()
    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            embedding = np.array(embedding_file[protein_sequence.id])

            annotations = annotation_extractor.get_conservation(embedding)
            cons_sequence = deepcopy(protein_sequence)
            cons_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.conservation))
            cons_sequences.append(cons_sequence)

    # Write files
    write_fasta_file(cons_sequences, conservation_predictions_file_path)
    return result_kwargs
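Invocation sketch with placeholder paths; the model argument value is the one suggested by the docstring:

result = prott5cons(
    "t5_xl_u50_conservation",
    embeddings_file="embeddings_file.h5",                # placeholder
    mapping_file="mapping_file.csv",                     # placeholder
    remapped_sequences_file="remapped_sequences.fasta",  # placeholder
    prefix="run_output",
    stage_name="prott5cons",
)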
Example #13
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        remapped_sequences_file: Where the (remapped) sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: The mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name

    Returns
    -------
    Dictionary with results of stage
    """
    embedder_class, result_kwargs = prepare_kwargs(**kwargs)

    # Download necessary files if needed
    for file in embedder_class.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=embedder_class.name,
                                                 file=file)

    for directory in embedder_class.necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(
                model=embedder_class.name, directory=directory)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
Example #14
    def __init__(self,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        """
        Initialize annotation extractor. Paths of model files must be passed as non-positional (keyword) arguments.

        :param model_file: path of bindEmbed21DL inference model checkpoint file
        """

        self._options = kwargs
        self._device = get_device(device)

        # Create un-trained (raw) models
        self._binding_residue_model_1 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_2 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_3 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_4 = BindingResiduesCNN().to(self._device)
        self._binding_residue_model_5 = BindingResiduesCNN().to(self._device)

        # Download the checkpoint files if needed
        for file in self.necessary_files:
            if not self._options.get(file):
                self._options[file] = get_model_file(model="bindembed21dl",
                                                     file=file)

        self.model_file_1 = self._options['model_1_file']
        self.model_file_2 = self._options['model_2_file']
        self.model_file_3 = self._options['model_3_file']
        self.model_file_4 = self._options['model_4_file']
        self.model_file_5 = self._options['model_5_file']

        # load pre-trained weights for annotation machines
        binding_residue_state_1 = torch.load(self.model_file_1,
                                             map_location=self._device)
        binding_residue_state_2 = torch.load(self.model_file_2,
                                             map_location=self._device)
        binding_residue_state_3 = torch.load(self.model_file_3,
                                             map_location=self._device)
        binding_residue_state_4 = torch.load(self.model_file_4,
                                             map_location=self._device)
        binding_residue_state_5 = torch.load(self.model_file_5,
                                             map_location=self._device)

        # load pre-trained weights into raw model
        self._binding_residue_model_1.load_state_dict(
            binding_residue_state_1['state_dict'])
        self._binding_residue_model_2.load_state_dict(
            binding_residue_state_2['state_dict'])
        self._binding_residue_model_3.load_state_dict(
            binding_residue_state_3['state_dict'])
        self._binding_residue_model_4.load_state_dict(
            binding_residue_state_4['state_dict'])
        self._binding_residue_model_5.load_state_dict(
            binding_residue_state_5['state_dict'])

        # ensure that model is in evaluation mode (important for batchnorm and dropout)
        self._binding_residue_model_1.eval()
        self._binding_residue_model_2.eval()
        self._binding_residue_model_3.eval()
        self._binding_residue_model_4.eval()
        self._binding_residue_model_5.eval()
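For reference, a more compact construction in the loop style of Example #4; a sketch only, assuming no external code relies on the numbered attributes, and meant as a drop-in inside the same __init__:

self._binding_residue_models = []
for i in range(1, 6):
    model = BindingResiduesCNN().to(self._device)
    state = torch.load(self._options[f"model_{i}_file"],
                       map_location=self._device)
    model.load_state_dict(state["state_dict"])
    self._binding_residue_models.append(model.eval())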
Example #15
def deepblast(**kwargs) -> Dict[str, Any]:
    """Sequence-Sequence alignments with DeepBLAST

    DeepBLAST learned structural alignments from sequence

    https://github.com/flatironinstitute/deepblast

    https://www.biorxiv.org/content/10.1101/2020.11.03.365932v1
    """
    # TODO: Fix that logic before merging
    if "transferred_annotations_file" not in kwargs and "pairings_file" not in kwargs:
        raise MissingParameterError(
            "You need to specify either 'transferred_annotations_file' or 'pairings_file' for DeepBLAST"
        )
    if "transferred_annotations_file" in kwargs and "pairings_file" in kwargs:
        raise InvalidParameterError(
            "You can't specify both 'transferred_annotations_file' and 'pairings_file' for DeepBLAST"
        )
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # This stays below 8GB, so it should be a good default
    batch_size = result_kwargs.setdefault("batch_size", 50)

    if "device" in result_kwargs:
        device = torch.device(result_kwargs["device"])
        if device.type != "cuda":
            raise RuntimeError(
                f"You can only run DeepBLAST on a CUDA-compatible GPU, not on {device.type}"
            )
    else:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "DeepBLAST requires a CUDA-compatible GPU, but none was found")
        device = torch.device("cuda")

    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    mapping = {
        str(remapped): original
        for remapped, original in mapping_file[["original_id"]].itertuples()
    }

    query_by_id = {
        mapping[entry.name]: str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"],
                                 "fasta")
    }

    # You can either provide a set of pairings or use the output of k-nn with a fasta file for the reference embeddings
    if "pairings_file" in result_kwargs:
        pairings_file = read_csv(result_kwargs["pairings_file"])
        pairings = list(pairings_file[["query",
                                       "target"]].itertuples(index=False))
        target_by_id = query_by_id
    else:
        transferred_annotations_file = read_csv(
            result_kwargs["transferred_annotations_file"])
        pairings = []
        for _, row in transferred_annotations_file.iterrows():
            query = row["original_id"]
            for target in row.filter(regex="k_nn_.*_identifier"):
                pairings.append((query, target))

        target_by_id = {}
        for entry in read_fasta(result_kwargs["reference_fasta_file"]):
            target_by_id[entry.name] = str(entry.seq)

    # Create one output file per query
    result_kwargs["alignment_files"] = dict()
    for query in set(i for i, _ in pairings):
        filename = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            f"{slugify(query, lowercase=False)}_alignments",
            extension=".a2m",
        )
        result_kwargs["alignment_files"][query] = filename

    unknown_queries = set(list(zip(*pairings))[0]) - set(query_by_id.keys())
    if unknown_queries:
        raise ValueError(f"Unknown query sequences: {unknown_queries}")

    unknown_targets = set(list(zip(*pairings))[1]) - set(target_by_id.keys())
    if unknown_targets:
        raise ValueError(f"Unknown target sequences: {unknown_targets}")

    # Load the pretrained model
    if "model_file" not in result_kwargs:
        model_file = get_model_file("deepblast", "model_file")
    else:
        model_file = result_kwargs["model_file"]

    alignments = deepblast_align(pairings, query_by_id, target_by_id,
                                 model_file, device, batch_size)

    for query, query_alignments in itertools.groupby(alignments, key=lambda i: i[0]):
        _, targets, queries_aligned, targets_aligned = list(zip(*query_alignments))
        padded_query, padded_targets = pairwise_alignments_to_msa(
            queries_aligned, targets_aligned)
        with open(result_kwargs["alignment_files"][query], "w") as fp:
            fp.write(f">{query}\n")
            fp.write(f"{padded_query}\n")
            for target, padded_target in zip(targets, padded_targets):
                fp.write(f">{target}\n")
                fp.write(f"{padded_target}\n")

    return result_kwargs
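The 'pairings_file' branch above expects a CSV with 'query' and 'target' columns; a minimal example (the identifiers are hypothetical):

query,target
P12345,Q67890
P12345,O11111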
Example #16
        ),
        (
            lambda: ProtTransBertBFDEmbedder(),
            lambda: BasicAnnotationExtractor("bert_from_publication"),
            0.5387755102040817,
        ),
        (
            lambda: ProtTransT5XLU50Embedder(half_precision_model=True),
            lambda: BasicAnnotationExtractor("t5_xl_u50_from_publication"),
            0.6285714285714286,
        ),
        (
            lambda: ProtTransT5XLU50Embedder(half_precision_model=True),
            lambda: LightAttentionAnnotationExtractor(
                subcellular_location_checkpoint_file=get_model_file(
                    "la_prott5", "subcellular_location_checkpoint_file"
                ),
                membrane_checkpoint_file=get_model_file(
                    "la_prott5", "membrane_checkpoint_file"
                ),
            ),
            0.6551020408163265,
        ),
    ],
)
def test_basic_annotation_extractor(
    pytestconfig,
    get_embedder: Callable[[], EmbedderInterface],
    get_extractor: Callable[
        [], Union[BasicAnnotationExtractor, LightAttentionAnnotationExtractor]
    ],
Example #17
def bindembed21dl(**kwargs) -> Dict[str, Any]:
    """
    Protocol extracts binding residues from "embeddings_file".
    Results are only guaranteed for ProtT5-XL-U50 embeddings.

    :return:
    """

    check_required(kwargs, ['embeddings_file', 'mapping_file', 'remapped_sequences_file'])
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    for file in BindEmbed21DLAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model="bindembed21dl", file=file)

    annotation_extractor = BindEmbed21DLAnnotationExtractor(**result_kwargs)

    # Try to create final files (if this fails, now is better than later)
    metal_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                   result_kwargs.get('stage_name'),
                                                                   'metal_binding_predictions_file',
                                                                   extension='.fasta')
    result_kwargs['metal_binding_predictions_file'] = metal_binding_predictions_file_path
    nuc_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                 result_kwargs.get('stage_name'),
                                                                 'nucleic_acid_binding_predictions_file',
                                                                 extension='.fasta')
    result_kwargs['nucleic_acid_binding_predictions_file'] = nuc_binding_predictions_file_path
    small_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                   result_kwargs.get('stage_name'),
                                                                   'small_molecule_binding_predictions_file',
                                                                   extension='.fasta')
    result_kwargs['small_molecule_binding_predictions_file'] = small_binding_predictions_file_path

    metal_sequences = list()
    nuc_sequences = list()
    small_sequences = list()

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            embedding = np.array(embedding_file[protein_sequence.id])

            annotations = annotation_extractor.get_binding_residues(embedding)
            metal_sequence = deepcopy(protein_sequence)
            nuc_sequence = deepcopy(protein_sequence)
            small_sequence = deepcopy(protein_sequence)

            metal_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.metal_ion))
            nuc_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.nucleic_acids))
            small_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.small_molecules))

            metal_sequences.append(metal_sequence)
            nuc_sequences.append(nuc_sequence)
            small_sequences.append(small_sequence)

    # Write files
    write_fasta_file(metal_sequences, metal_binding_predictions_file_path)
    write_fasta_file(nuc_sequences, nuc_binding_predictions_file_path)
    write_fasta_file(small_sequences, small_binding_predictions_file_path)

    return result_kwargs
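Invocation sketch with placeholder paths; one fasta file is produced per ligand class:

result = bindembed21dl(
    embeddings_file="embeddings_file.h5",                # placeholder
    mapping_file="mapping_file.csv",                     # placeholder
    remapped_sequences_file="remapped_sequences.fasta",  # placeholder
    prefix="run_output",
    stage_name="bindembed21dl",
)
for key in ("metal_binding_predictions_file",
            "nucleic_acid_binding_predictions_file",
            "small_molecule_binding_predictions_file"):
    print(result[key])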
Example #18
def bindembed21(**kwargs) -> Dict[str, Any]:
    """
    Protocol extracts binding residues from "alignment_results_file" if possible, and from "embeddings_file" otherwise.
    :param kwargs:
    :return:
    """

    check_required(kwargs, ['alignment_results_file', 'embeddings_file', 'mapping_file', 'remapped_sequences_file'])
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Download necessary files if needed
    # for HBI
    for directory in BindEmbed21HBIAnnotationExtractor.necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(model="bindembed21hbi", directory=directory)
    # for DL
    for file in BindEmbed21DLAnnotationExtractor.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model="bindembed21dl", file=file)

    hbi_extractor = BindEmbed21HBIAnnotationExtractor(**result_kwargs)
    dl_extractor = BindEmbed21DLAnnotationExtractor(**result_kwargs)

    # Try to create final files (if this fails, now is better than later)
    metal_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                   result_kwargs.get('stage_name'),
                                                                   'metal_binding_predictions_file',
                                                                   extension='.fasta')
    result_kwargs['metal_binding_predictions_file'] = metal_binding_predictions_file_path
    nuc_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                 result_kwargs.get('stage_name'),
                                                                 'nucleic_acid_binding_predictions_file',
                                                                 extension='.fasta')
    result_kwargs['nucleic_acid_binding_predictions_file'] = nuc_binding_predictions_file_path
    small_binding_predictions_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                                   result_kwargs.get('stage_name'),
                                                                   'small_molecule_binding_predictions_file',
                                                                   extension='.fasta')
    result_kwargs['small_molecule_binding_predictions_file'] = small_binding_predictions_file_path

    metal_sequences = list()
    nuc_sequences = list()
    small_sequences = list()

    alignment_results = read_csv(result_kwargs['alignment_results_file'], sep='\t',
                                 dtype={'query': 'str', 'target': 'str'})
    alignment_results = alignment_results[alignment_results['eval'] < 1E-3].copy()

    with h5py.File(result_kwargs['embeddings_file'], 'r') as embedding_file:
        for protein_sequence in read_fasta(result_kwargs['remapped_sequences_file']):
            # get HBI hit for this query
            hits = alignment_results[alignment_results['query'].str.match(str(protein_sequence.id))].copy()
            hits_min_eval = hits[hits['eval'] == min(hits['eval'])]
            hit_max_pide = hits_min_eval[hits_min_eval['fident'] == max(hits_min_eval['fident'])]

            metal_sequence = deepcopy(protein_sequence)
            nuc_sequence = deepcopy(protein_sequence)
            small_sequence = deepcopy(protein_sequence)

            hbi_annotations = hbi_extractor.get_binding_residues(hit_max_pide.iloc[0].to_dict())
            metal_inference = convert_list_of_enum_to_string(hbi_annotations.metal_ion)
            nuc_inference = convert_list_of_enum_to_string(hbi_annotations.nucleic_acids)
            small_inference = convert_list_of_enum_to_string(hbi_annotations.small_molecules)

            # some part of the sequence was predicted using HBI --> save output and don't run DL method
            if 'M' in metal_inference or 'N' in nuc_inference or 'S' in small_inference:
                metal_sequence.seq = Seq(metal_inference)
                nuc_sequence.seq = Seq(nuc_inference)
                small_sequence.seq = Seq(small_inference)
            # no inference containing binding annotations was made --> run bindEmbed21DL
            else:
                embedding = np.array(embedding_file[protein_sequence.id])
                annotations = dl_extractor.get_binding_residues(embedding)
                metal_sequence = deepcopy(protein_sequence)
                nuc_sequence = deepcopy(protein_sequence)
                small_sequence = deepcopy(protein_sequence)

                metal_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.metal_ion))
                nuc_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.nucleic_acids))
                small_sequence.seq = Seq(convert_list_of_enum_to_string(annotations.small_molecules))

            metal_sequences.append(metal_sequence)
            nuc_sequences.append(nuc_sequence)
            small_sequences.append(small_sequence)

    # Write files
    write_fasta_file(metal_sequences, metal_binding_predictions_file_path)
    write_fasta_file(nuc_sequences, nuc_binding_predictions_file_path)
    write_fasta_file(small_sequences, small_binding_predictions_file_path)

    return result_kwargs
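The 'alignment_results_file' read above is a tab-separated table with at least 'query', 'target', 'eval' and 'fident' columns (the naming suggests MMseqs2-style output); a minimal example with hypothetical values, where only rows with eval < 1e-3 survive the filter:

query	target	eval	fident
seq_1	P12345	1e-10	0.85
seq_2	Q67890	5e-2	0.40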