def embed_and_write_batched(
    embedder: EmbedderInterface, file_manager: FileManagerInterface, result_kwargs: Dict[str, Any], half_precision: bool = False
) -> Dict[str, Any]:
    """The shared code between the SeqVec, Albert, Bert and XLNet pipelines.

    Streams sequences from the remapped FASTA file through the embedder and
    writes per-amino-acid, per-protein (reduced) and/or transformed embeddings
    to the files selected by ``result_kwargs``, one dataset per sequence.

    :param embedder: produces embeddings via ``embed_many`` / ``reduce_per_protein``
    :param file_manager: used by the ``_get_*_file_context`` helpers to open output files
    :param result_kwargs: pipeline configuration; reads ``remapped_sequences_file``,
        ``mapping_file``, ``max_amino_acids``, ``discard_per_amino_acid_embeddings``,
        ``reduce`` and ``embeddings_transformer_function``
    :param half_precision: if True, cast each embedding to float16 before writing
    :return: the (unmodified) ``result_kwargs``, following the pipeline-stage convention
    """
    # Lazy fasta file reader. The mapping file contains the corresponding ids in the same order
    sequences = (
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    )

    # We want to read the unnamed column 0 as str (esp. with simple_remapping), which requires some workarounds
    # https://stackoverflow.com/a/29793294/3549270
    mapping_file = read_csv(result_kwargs["mapping_file"], index_col=0)
    mapping_file.index = mapping_file.index.astype('str')

    # Print the minimum required file sizes
    _print_expected_file_sizes(embedder, mapping_file, result_kwargs)

    # Get transformer function, if available.
    # NOTE(review): eval of a config string — assumed to have been validated
    # upstream (see _check_transform_embeddings_function); only `np` is exposed.
    transform_function = result_kwargs.get("embeddings_transformer_function", None)
    if transform_function:
        transform_function = eval(transform_function, {}, {"np": numpy})

    # Open embedding files or null contexts and iteratively save embeddings to file
    with _get_embeddings_file_context(
        file_manager, result_kwargs
    ) as embeddings_file, _get_reduced_embeddings_file_context(
        file_manager, result_kwargs
    ) as reduced_embeddings_file, _get_transformed_embeddings_file_context(
        file_manager, result_kwargs
    ) as transformed_embeddings_file:
        embedding_generator = embedder.embed_many(
            sequences, result_kwargs.get("max_amino_acids")
        )
        # mapping_file rows and the FASTA entries are in the same order, so a
        # plain zip pairs each embedding with its (remapped id, original id).
        for sequence_id, original_id, embedding in zip(
            mapping_file.index,
            mapping_file["original_id"],
            tqdm(embedding_generator, total=len(mapping_file))
        ):
            # embedding: numpy.ndarray
            if half_precision:
                embedding = embedding.astype(numpy.float16)
            # Deliberate `is False`: per-amino-acid embeddings are written only
            # when the key is explicitly set to False (absent key -> skipped).
            if result_kwargs.get("discard_per_amino_acid_embeddings") is False:
                dataset = embeddings_file.create_dataset(sequence_id, data=embedding)
                dataset.attrs["original_id"] = original_id
            # Likewise `is True`: reduction only when explicitly requested.
            if result_kwargs.get("reduce") is True:
                dataset = reduced_embeddings_file.create_dataset(
                    sequence_id, data=embedder.reduce_per_protein(embedding)
                )
                dataset.attrs["original_id"] = original_id
            if transform_function:
                dataset = transformed_embeddings_file.create_dataset(
                    sequence_id, data=numpy.array(transform_function(embedding))
                )
                dataset.attrs["original_id"] = original_id

    return result_kwargs
def _check_transform_embeddings_function(embedder: EmbedderInterface, result_kwargs: Dict[str, Any]):
    """Validate the optional ``embeddings_transformer_function`` in ``result_kwargs``.

    Ensures the configured expression evaluates to a callable, that the callable
    runs on a template embedding, and that its result can be cast to a numpy
    array. Sets the key to ``None`` when absent and returns early in that case.

    :param embedder: used to produce a template embedding for the dry-run check
    :param result_kwargs: pipeline configuration; ``embeddings_transformer_function``
        is normalized in place via ``setdefault``
    :raises InvalidParameterError: if the expression is not a valid, working callable
    """
    result_kwargs.setdefault("embeddings_transformer_function", None)
    # Guard clause: nothing to validate when no transformer was configured.
    if result_kwargs["embeddings_transformer_function"] is None:
        return

    # NOTE(review): eval on a config-supplied string — trusted configuration
    # only; the expression sees nothing but `np`.
    try:
        transform_function = eval(result_kwargs["embeddings_transformer_function"], {}, {"np": numpy})
    except TypeError as error:
        raise InvalidParameterError(f"`embeddings_transformer_function` must be callable! \n"
                                    f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                    f"Most likely you want a lambda function.") from error

    if not callable(transform_function):
        raise InvalidParameterError(f"`embeddings_transformer_function` must be callable! \n"
                                    f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                    f"Most likely you want a lambda function.")

    template_embedding = embedder.embed("SEQVENCE")

    # Check that it works in principle.
    # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt pass through.
    try:
        transformed_template_embedding = transform_function(template_embedding)
    except Exception as error:
        raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable! \n"
                                    f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                    f"This function excepts when processing an embedding.") from error

    # Check that return can be cast to np.array
    try:
        numpy.array(transformed_template_embedding)
    except Exception as error:
        raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable "
                                    f"returning numpy array compatible object! \n"
                                    f"Instead is {result_kwargs['embeddings_transformer_function']}\n"
                                    f"This function excepts when processing an embedding.") from error
def check_embedding(embedder: EmbedderInterface, embedding, sequence: str):
    """Sanity-check that an embedding's shape is plausible for its embedder."""
    assert isinstance(embedding, ndarray)
    seq_len = len(sequence)
    embedder_class = embedder.__class__

    if embedder_class == SeqVecEmbedder:
        assert embedding.shape[1] == seq_len
    elif embedder_class == UniRepEmbedder:
        # See https://github.com/ElArkk/jax-unirep/issues/85
        assert embedding.shape[0] == seq_len + 1
    elif embedder_class == CPCProtEmbedder:
        # There is only a per-protein embedding for CPCProt
        assert embedding.shape == (512,)
    else:
        assert embedding.shape[0] == seq_len

    # Check reduce_per_protein
    # https://github.com/sacdallago/bio_embeddings/issues/85
    per_protein = embedder.reduce_per_protein(embedding)
    assert per_protein.shape == (embedder.embedding_dimension,)
def check_embedding(embedder: EmbedderInterface, embedding, sequence: str):
    """Sanity-check that an embedding's shape is plausible for its embedder."""
    # The ndarray isinstance check is currently disabled.
    # assert isinstance(embedding, ndarray)  # TODO: Fix unirep and reenable
    expected = len(sequence)
    cls = embedder.__class__

    if cls == CPCProtEmbedder:
        # There is only a per-protein embedding for CPCProt
        assert embedding.shape == (512, )
    elif cls == SeqVecEmbedder:
        assert embedding.shape[1] == expected
    elif cls == UniRepEmbedder:
        # Not sure why this is one longer, but the jax-unirep tests check
        # `len(sequence) + 1`, so it seems to be intended
        assert embedding.shape[0] == expected + 1
    else:
        assert embedding.shape[0] == expected

    # Check reduce_per_protein
    # https://github.com/sacdallago/bio_embeddings/issues/85
    reduced = embedder.reduce_per_protein(embedding)
    assert reduced.shape == (embedder.embedding_dimension, )
def test_foo(sequence, cached_embedder: EmbedderInterface):
    """Embed one sequence with the cached embedder and validate its shape."""
    check_embedding(cached_embedder, cached_embedder.embed(sequence), sequence)