Code Example #1
def test_basic_annotation_extractor(
    pytestconfig,
    get_embedder: Callable[[], EmbedderInterface],
    get_extractor: Callable[
        [], Union[BasicAnnotationExtractor, LightAttentionAnnotationExtractor]
    ],
    expected_accuracy: float,
):
    """Check that BasicAnnotationExtractor passes (without checking correctness)"""
    extractor = get_extractor()
    embedder = get_embedder()

    results = []
    new_hard_set = pytestconfig.rootpath.joinpath(
        "test-data/subcellular_location_new_hard_set.fasta"
    )
    for record in tqdm(read_fasta(str(new_hard_set))):
        embedding = embedder.embed(str(record.seq))
        localization = extractor.get_subcellular_location(embedding)
        expected_localization = normalize_location(
            record.description.split(" ")[1][:-2]
        )
        actual_localization = normalize_location(str(localization.localization))
        results.append(expected_localization == actual_localization)

    actual_accuracy = numpy.asarray(results).mean()
    assert actual_accuracy == pytest.approx(expected_accuracy)
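
For reference, the accuracy computation above reduces to comparing normalized labels and averaging the boolean results. A minimal self-contained sketch, with a stand-in normalize_location (the test's real helper is not shown in this excerpt and may differ):

import numpy

def normalize_location(location: str) -> str:
    # Stand-in: trim whitespace and unify case so label variants compare equal.
    return location.strip().lower()

expected = ["Nucleus", "Cytoplasm ", "Mitochondrion"]
predicted = ["nucleus", "cytoplasm", "Cell membrane"]

results = [
    normalize_location(e) == normalize_location(p)
    for e, p in zip(expected, predicted)
]
print(numpy.asarray(results).mean())  # 2/3: two of the three labels match
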
Code Example #2
def _process_fasta_file(**kwargs):
    """
    Assigns an MD5 hash as the ID if none is provided for a sequence.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    sequences = read_fasta(kwargs['sequences_file'])
    sequences_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                   None,
                                                   'sequences_file',
                                                   extension='.fasta')
    write_fasta_file(sequences, sequences_file_path)

    result_kwargs['sequences_file'] = sequences_file_path

    # Remap using sequence position rather than md5 hash -- not encouraged!
    result_kwargs['simple_remapping'] = result_kwargs.get(
        'simple_remapping', False)

    mapping = reindex_sequences(sequences,
                                simple=result_kwargs['simple_remapping'])

    # Check if there's the same MD5 index twice. This most likely indicates 100% sequence identity.
    # Throw an error for MD5 hash clashes!
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                 None,
                                                 'mapping_file',
                                                 extension='.csv')
    remapped_sequence_file_path = file_manager.create_file(
        kwargs.get('prefix'),
        None,
        'remapped_sequences_file',
        extension='.fasta')

    write_fasta_file(sequences, remapped_sequence_file_path)
    mapping.to_csv(mapping_file_path)

    result_kwargs['mapping_file'] = mapping_file_path
    result_kwargs['remapped_sequences_file'] = remapped_sequence_file_path

    return result_kwargs
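
A hypothetical direct invocation, assuming the pipeline forwards at least sequences_file and prefix (in the real pipeline these kwargs come from the configuration, and get_file_manager resolves the storage backend from them):

# All paths are placeholders.
result = _process_fasta_file(
    sequences_file="input.fasta",  # FASTA file to read and re-index
    prefix="run_output",           # prefix used by the file manager
    simple_remapping=False,        # keep the default MD5-based remapping
)
print(result["mapping_file"])             # CSV mapping new IDs to original IDs
print(result["remapped_sequences_file"])  # FASTA with remapped identifiers
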
Code Example #3
def test_batching_t5(pytestconfig):
    """Check that T5 batching is still failing"""
    embedder = ProtTransT5BFDEmbedder()
    fasta_file = pytestconfig.rootpath.joinpath("examples/docker/fasta.fa")
    batch = [str(i.seq[:]) for i in read_fasta(str(fasta_file))]
    embeddings_single_sequence = list(
        super(ProtTransT5Embedder, embedder).embed_many(batch,
                                                        batch_size=None))
    embeddings_batched = list(
        super(ProtTransT5Embedder, embedder).embed_many(batch,
                                                        batch_size=10000))
    for a, b in zip(embeddings_single_sequence, embeddings_batched):
        assert not numpy.allclose(a, b) and numpy.allclose(
            a, b, rtol=1.0e-4, atol=1.0e-5)
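
The double allclose assertion separates "bitwise identical" from "numerically equivalent": batching is expected to introduce tiny numerical drift, but nothing beyond loose tolerances. A standalone illustration of the two tolerance regimes:

import numpy

a = numpy.array([1.0, 2.0])
b = numpy.array([1.0001, 2.0002])  # small batching-style drift

# Default tolerances (rtol=1e-05, atol=1e-08) treat the arrays as different...
print(numpy.allclose(a, b))                            # False
# ...while the looser tolerances from the test accept them as equal.
print(numpy.allclose(a, b, rtol=1.0e-4, atol=1.0e-5))  # True
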
Code Example #4
File: pipeline.py  Project: bizzmug/bio_embeddings
def _process_fasta_file(**kwargs):
    """
    Assigns an MD5 hash as the ID if none is provided for a sequence.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    sequences = read_fasta(kwargs['sequences_file'])

    # Sanity check the fasta file to avoid nonsense and/or crashes by the embedders
    letters = set(string.ascii_letters)
    for entry in sequences:
        illegal = sorted(set(entry.seq) - letters)
        if illegal:
            formatted = "'" + "', '".join(illegal) + "'"
            raise ValueError(
                f"The entry '{entry.name}' in {kwargs['sequences_file']} contains the characters {formatted}, "
                f"while only single letter code is allowed "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )
        # This is a warning due to the inconsistent handling between different embedders
        if not str(entry.seq).isupper():
            logger.warning(
                f"The entry '{entry.name}' in {kwargs['sequences_file']} contains lower case amino acids. "
                f"Lower case letters are uninterpretable by most language models, "
                f"and their embedding will be nonesensical. "
                f"Protein LMs available through bio_embeddings have been trained on upper case, "
                f"single letter code sequence representations only "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )

    sequences_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                   None,
                                                   'sequences_file',
                                                   extension='.fasta')
    write_fasta_file(sequences, sequences_file_path)

    result_kwargs['sequences_file'] = sequences_file_path

    # Remap using sequence position rather than md5 hash -- not encouraged!
    result_kwargs['simple_remapping'] = result_kwargs.get(
        'simple_remapping', False)

    mapping = reindex_sequences(sequences,
                                simple=result_kwargs['simple_remapping'])

    # Check if there's the same MD5 index twice. This most likely indicates 100% sequence identity.
    # Throw an error for MD5 hash clashes!
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                 None,
                                                 'mapping_file',
                                                 extension='.csv')
    remapped_sequence_file_path = file_manager.create_file(
        kwargs.get('prefix'),
        None,
        'remapped_sequences_file',
        extension='.fasta')

    write_fasta_file(sequences, remapped_sequence_file_path)
    mapping.to_csv(mapping_file_path)

    result_kwargs['mapping_file'] = mapping_file_path
    result_kwargs['remapped_sequences_file'] = remapped_sequence_file_path

    return result_kwargs
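
The character sanity check in this variant stands on its own; a self-contained sketch of the same idea on plain strings (the helper name is ours, not part of bio_embeddings):

import string

def find_illegal_residues(sequence: str) -> list:
    # Mirrors the check above: anything outside A-Z/a-z is rejected.
    return sorted(set(sequence) - set(string.ascii_letters))

print(find_illegal_residues("MKTAYIAKQR"))  # [] -> entry is accepted
print(find_illegal_residues("MKT*YI-KQR"))  # ['*', '-'] -> would raise ValueError
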
Code Example #5
def deepblast(**kwargs) -> Dict[str, Any]:
    """Sequence-Sequence alignments with DeepBLAST

    DeepBLAST learned structural alignments from sequence

    https://github.com/flatironinstitute/deepblast

    https://www.biorxiv.org/content/10.1101/2020.11.03.365932v1
    """
    # TODO: Fix that logic before merging
    if "transferred_annotations_file" not in kwargs and "pairings_file" not in kwargs:
        raise MissingParameterError(
            "You need to specify either 'transferred_annotations_file' or 'pairings_file' for DeepBLAST"
        )
    if "transferred_annotations_file" in kwargs and "pairings_file" in kwargs:
        raise InvalidParameterError(
            "You can't specify both 'transferred_annotations_file' and 'pairings_file' for DeepBLAST"
        )
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # This stays below 8GB, so it should be a good default
    batch_size = result_kwargs.setdefault("batch_size", 50)

    if "device" in result_kwargs:
        device = torch.device(result_kwargs["device"])
        if device.type != "cuda":
            raise RuntimeError(
                f"You can only run DeepBLAST on a CUDA-compatible GPU, not on {device.type}"
            )
    else:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "DeepBLAST requires a CUDA-compatible GPU, but none was found")
        device = torch.device("cuda")

    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    mapping = {
        str(remapped): original
        for remapped, original in mapping_file[["original_id"]].itertuples()
    }

    query_by_id = {
        mapping[entry.name]: str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"],
                                 "fasta")
    }

    # You can either provide a set of pairings, or use the k-nn output together with a FASTA file for the reference embeddings
    if "pairings_file" in result_kwargs:
        pairings_file = read_csv(result_kwargs["pairings_file"])
        pairings = list(pairings_file[["query",
                                       "target"]].itertuples(index=False))
        target_by_id = query_by_id
    else:
        transferred_annotations_file = read_csv(
            result_kwargs["transferred_annotations_file"])
        pairings = []
        for _, row in transferred_annotations_file.iterrows():
            query = row["original_id"]
            for target in row.filter(regex="k_nn_.*_identifier"):
                pairings.append((query, target))

        target_by_id = {}
        for entry in read_fasta(result_kwargs["reference_fasta_file"]):
            target_by_id[entry.name] = str(entry.seq)

    # Create one output file per query
    result_kwargs["alignment_files"] = dict()
    for query in set(i for i, _ in pairings):
        filename = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            f"{slugify(query, lowercase=False)}_alignments",
            extension=".a2m",
        )
        result_kwargs["alignment_files"][query] = filename

    unknown_queries = set(list(zip(*pairings))[0]) - set(query_by_id.keys())
    if unknown_queries:
        raise ValueError(f"Unknown query sequences: {unknown_queries}")

    unknown_targets = set(list(zip(*pairings))[1]) - set(target_by_id.keys())
    if unknown_targets:
        raise ValueError(f"Unknown target sequences: {unknown_targets}")

    # Load the pretrained model
    if "model_file" not in result_kwargs:
        model_file = get_model_file("deepblast", "model_file")
    else:
        model_file = result_kwargs["model_file"]

    alignments = deepblast_align(pairings, query_by_id, target_by_id,
                                 model_file, device, batch_size)

    # Avoid shadowing the `alignments` list with the loop variable
    for query, group in itertools.groupby(alignments, key=lambda i: i[0]):
        _, targets, queries_aligned, targets_aligned = list(zip(*group))
        padded_query, padded_targets = pairwise_alignments_to_msa(
            queries_aligned, targets_aligned)
        with open(result_kwargs["alignment_files"][query], "w") as fp:
            fp.write(f">{query}\n")
            fp.write(f"{padded_query}\n")
            for target, padded_target in zip(targets, padded_targets):
                fp.write(f">{target}\n")
                fp.write(f"{padded_target}\n")

    return result_kwargs
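
A sketch of calling the stage directly, assuming the upstream stages already produced the mapping and remapped FASTA files (in the real pipeline these kwargs come from the protocol configuration; exactly one of pairings_file and transferred_annotations_file must be set, per the checks at the top):

# All paths are placeholders.
result = deepblast(
    mapping_file="run/mapping_file.csv",
    remapped_sequences_file="run/remapped_sequences_file.fasta",
    pairings_file="pairings.csv",  # CSV with 'query' and 'target' columns
    prefix="run",
    stage_name="deepblast",
    device="cuda:0",               # a CUDA device is required
    batch_size=50,                 # the default; stays below 8 GB per the comment above
)
for query, a2m_path in result["alignment_files"].items():
    print(query, "->", a2m_path)
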