Example #1
def embed_and_write_batched(
    embedder: EmbedderInterface,
    file_manager: FileManagerInterface,
    result_kwargs: Dict[str, Any],
    half_precision: bool = False
) -> Dict[str, Any]:
    """ The shared code between the SeqVec, Albert, Bert and XLNet pipelines """
    # Lazy fasta file reader. The mapping file contains the corresponding ids in the same order
    sequences = (
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    )
    # We want to read the unnamed column 0 as str (esp. with simple_remapping), which requires some workarounds
    # https://stackoverflow.com/a/29793294/3549270
    mapping_file = read_csv(result_kwargs["mapping_file"], index_col=0)
    mapping_file.index = mapping_file.index.astype('str')

    # Print the minimum required file sizes
    _print_expected_file_sizes(embedder, mapping_file, result_kwargs)

    # Get transformer function, if available
    transform_function = result_kwargs.get("embeddings_transformer_function", None)

    if transform_function:
        transform_function = eval(transform_function, {}, {"np": numpy})

    # Open embedding files or null contexts and iteratively save embeddings to file
    with _get_embeddings_file_context(
        file_manager, result_kwargs
    ) as embeddings_file, _get_reduced_embeddings_file_context(
        file_manager, result_kwargs
    ) as reduced_embeddings_file, _get_transformed_embeddings_file_context(
        file_manager, result_kwargs
    ) as transformed_embeddings_file:
        embedding_generator = embedder.embed_many(
            sequences, result_kwargs.get("max_amino_acids")
        )
        for sequence_id, original_id, embedding in zip(
            mapping_file.index,
            mapping_file["original_id"],
            tqdm(embedding_generator, total=len(mapping_file))
        ):
            # embedding: numpy.ndarray
            if half_precision:
                embedding = embedding.astype(numpy.float16)
            if result_kwargs.get("discard_per_amino_acid_embeddings") is False:
                dataset = embeddings_file.create_dataset(sequence_id, data=embedding)
                dataset.attrs["original_id"] = original_id
            if result_kwargs.get("reduce") is True:
                dataset = reduced_embeddings_file.create_dataset(
                    sequence_id, data=embedder.reduce_per_protein(embedding)
                )
                dataset.attrs["original_id"] = original_id
            if transform_function:
                dataset = transformed_embeddings_file.create_dataset(
                    sequence_id, data=numpy.array(transform_function(embedding))
                )
                dataset.attrs["original_id"] = original_id

    return result_kwargs
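
The `embeddings_transformer_function` used above arrives as a string in the pipeline configuration and is turned into a callable via `eval`, with numpy exposed as `np`. Below is a minimal, self-contained sketch of that mechanism; the transform string and the embedding shape are illustrative assumptions, not values taken from the pipeline.

import numpy

# Hypothetical configuration value; the real pipeline reads it from result_kwargs.
transform_source = "lambda embedding: embedding.mean(axis=0)"

# Same eval pattern as above: empty globals, numpy exposed as `np` in locals.
# Note that names referenced inside a lambda body are resolved against the
# globals mapping, so a transform written as `lambda x: np.mean(x, axis=0)`
# would need numpy exposed in globals instead.
transform_function = eval(transform_source, {}, {"np": numpy})

per_residue_embedding = numpy.random.rand(7, 1024)  # dummy (residues x features) array
per_protein_embedding = numpy.array(transform_function(per_residue_embedding))
assert per_protein_embedding.shape == (1024,)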
Example #2
def _check_transform_embeddings_function(embedder: EmbedderInterface, result_kwargs: Dict[str, Any]):

    result_kwargs.setdefault("embeddings_transformer_function", None)

    if result_kwargs["embeddings_transformer_function"] is not None:
        try:
            transform_function = eval(result_kwargs["embeddings_transformer_function"], {}, {"np": numpy})
        except TypeError:
            raise InvalidParameterError(f"`embeddings_transformer_function` must be callable! \n"
                                        f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                        f"Most likely you want a lambda function.")

        if not callable(transform_function):
            raise InvalidParameterError(f"`embeddings_transformer_function` must be callable! \n"
                                        f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                        f"Most likely you want a lambda function.")

        template_embedding = embedder.embed("SEQVENCE")

        # Check that it works in principle
        try:
            transformed_template_embedding = transform_function(template_embedding)
        except Exception:
            raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable! \n"
                                        f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                        f"This function excepts when processing an embedding.")

        # Check that return can be cast to np.array
        try:
            numpy.array(transformed_template_embedding)
        except Exception:
            raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable "
                                        f"returning numpy array compatible object! \n"
                                        f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                        f"This function excepts when processing an embedding.")
Example #3
def check_embedding(embedder: EmbedderInterface, embedding, sequence: str):
    """Checks that the shape of the embeddings looks credible"""
    assert isinstance(embedding, ndarray)
    if embedder.__class__ == SeqVecEmbedder:
        assert embedding.shape[1] == len(sequence)
    elif embedder.__class__ == UniRepEmbedder:
        # See https://github.com/ElArkk/jax-unirep/issues/85
        assert embedding.shape[0] == len(sequence) + 1
    elif embedder.__class__ == CPCProtEmbedder:
        # There is only a per-protein embedding for CPCProt
        assert embedding.shape == (512,)
    else:
        assert embedding.shape[0] == len(sequence)

    # Check reduce_per_protein
    # https://github.com/sacdallago/bio_embeddings/issues/85
    assert embedder.reduce_per_protein(embedding).shape == (
        embedder.embedding_dimension,
    )
Example #4
def check_embedding(embedder: EmbedderInterface, embedding, sequence: str):
    """Checks that the shape of the embeddings looks credible"""
    # assert isinstance(embedding, ndarray) # TODO: Fix unirep and reenable
    if embedder.__class__ == SeqVecEmbedder:
        assert embedding.shape[1] == len(sequence)
    elif embedder.__class__ == UniRepEmbedder:
        # Not sure why this is one longer, but the jax-unirep tests check
        # `len(sequence) + 1`, so it seems to be intended
        assert embedding.shape[0] == len(sequence) + 1
    elif embedder.__class__ == CPCProtEmbedder:
        # There is only a per-protein embedding for CPCProt
        assert embedding.shape == (512, )
    else:
        assert embedding.shape[0] == len(sequence)

    # Check reduce_per_protein
    # https://github.com/sacdallago/bio_embeddings/issues/85
    assert embedder.reduce_per_protein(embedding).shape == (
        embedder.embedding_dimension, )
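
For reference, the dummy arrays below satisfy the shape conventions that the two check_embedding variants encode. The checks above only constrain the sequence-length axes and the CPCProt shape; the concrete feature dimensions used here are assumptions chosen to be consistent with those asserts.

import numpy

sequence = "SEQVENCE"  # length 8

seqvec_like = numpy.zeros((3, len(sequence), 1024))    # layers x residues x features
per_residue_like = numpy.zeros((len(sequence), 1024))  # residues x features (Bert/XLNet style)
unirep_like = numpy.zeros((len(sequence) + 1, 1900))   # jax-unirep returns one extra position
cpcprot_like = numpy.zeros((512,))                     # per-protein embedding only

assert seqvec_like.shape[1] == len(sequence)
assert per_residue_like.shape[0] == len(sequence)
assert unirep_like.shape[0] == len(sequence) + 1
assert cpcprot_like.shape == (512,)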
Example #5
def test_foo(sequence, cached_embedder: EmbedderInterface):
    embedding = cached_embedder.embed(sequence)
    check_embedding(cached_embedder, embedding, sequence)
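
The fixtures behind test_foo are not shown here; a possible pytest wiring is sketched below. The fixture names match the test's parameters, but the concrete choices (a single short sequence, a module-scoped SeqVecEmbedder, and the assumption that a no-argument constructor fetches the default model files) are illustrative only.

import pytest

@pytest.fixture
def sequence() -> str:
    return "SEQVENCE"

@pytest.fixture(scope="module")
def cached_embedder() -> EmbedderInterface:
    # Any concrete embedder would do; SeqVecEmbedder is used purely as an
    # illustration and is assumed to download its default weights when
    # constructed without arguments.
    return SeqVecEmbedder()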