def _check_transform_embeddings_function(embedder: EmbedderInterface, result_kwargs: Dict[str, Any]):
    """
    Validate the optional `embeddings_transformer_function` parameter.

    The parameter is a string containing a Python expression (most likely a
    lambda). It is evaluated with `np` bound to numpy, checked to be callable,
    applied to the embedding of a template sequence, and the result is checked
    to be convertible with `numpy.array`.

    :param embedder: The embedder used to embed the template sequence
    :param result_kwargs: The stage parameters. `embeddings_transformer_function`
        is set to None in-place if it is missing
    :raises InvalidParameterError: if the expression cannot be evaluated, is not
        callable, raises when applied to the template embedding, or returns
        something that cannot be converted to a numpy array
    """
    result_kwargs.setdefault("embeddings_transformer_function", None)
    function_source = result_kwargs["embeddings_transformer_function"]

    # Nothing configured, nothing to validate
    if function_source is None:
        return

    not_callable_message = (
        f"`embeddings_transformer_function` must be callable! \n"
        f"Instead is {function_source}\n"
        f"Most likely you want a lambda function."
    )

    # SECURITY NOTE(review): eval executes arbitrary code taken from the
    # configuration. Only run this with configurations from trusted sources.
    try:
        transform_function = eval(function_source, {}, {"np": numpy})
    except Exception as e:
        # Not only TypeError (non-string input): malformed expressions raise
        # e.g. SyntaxError or NameError, which must also be reported as an
        # invalid parameter instead of escaping as a raw exception
        raise InvalidParameterError(not_callable_message) from e

    if not callable(transform_function):
        raise InvalidParameterError(not_callable_message)

    template_embedding = embedder.embed("SEQVENCE")

    # Check that it works in principle
    try:
        transformed_template_embedding = transform_function(template_embedding)
    except Exception as e:
        # `Exception` instead of a bare except, so that KeyboardInterrupt and
        # SystemExit are not swallowed
        raise InvalidParameterError(
            f"`embeddings_transformer_function` must be valid callable! \n"
            f"Instead is {function_source}\n"
            f"This function excepts when processing an embedding."
        ) from e

    # Check that return can be cast to np.array
    try:
        numpy.array(transformed_template_embedding)
    except Exception as e:
        raise InvalidParameterError(
            f"`embeddings_transformer_function` must be valid callable "
            f"returning numpy array compatible object! \n"
            f"Instead is {function_source}\n"
            f"This function excepts when processing an embedding."
        ) from e
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        remapped_sequences_file: Required by this stage (checked via check_required)
        prefix: Output prefix for all generated files
        protocol: Which embedder to use, must be a key of PROTOCOLS
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        use_cpu*: Must not be set when the selected embedder is UniRepEmbedder
        max_amino_acids*: Defaults to the entry of DEFAULT_MAX_AMINO_ACIDS for the protocol
        half_precision*: Forwarded to embed_and_write_batched, defaults to False

    Returns
    -------
    Dictionary with results of stage
    """
    required_parameters = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_parameters)

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        valid_protocols = ", ".join(PROTOCOLS.keys())
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                protocol, valid_protocols
            )
        )

    embedder_class = PROTOCOLS[protocol]

    use_cpu_is_configured = kwargs.get("use_cpu") is not None
    if embedder_class == UniRepEmbedder and use_cpu_is_configured:
        raise InvalidParameterError(
            "UniRep does not support configuring `use_cpu`")

    # Work on a copy so the caller's parameters stay untouched
    result_kwargs = deepcopy(kwargs)

    # Download necessary files if needed
    # noinspection PyProtectedMember
    for file_name in embedder_class._necessary_files:
        if result_kwargs.get(file_name):
            continue
        result_kwargs[file_name] = get_model_file(
            model=embedder_class.name, file=file_name
        )

    # Same for the necessary directories
    # noinspection PyProtectedMember
    for directory_name in embedder_class._necessary_directories:
        if result_kwargs.get(directory_name):
            continue
        result_kwargs[directory_name] = get_model_directories_from_zip(
            model=embedder_class.name, directory=directory_name
        )

    default_max_amino_acids = DEFAULT_MAX_AMINO_ACIDS[protocol]
    result_kwargs.setdefault("max_amino_acids", default_max_amino_acids)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(
        embedder, file_manager, result_kwargs, half_precision
    )
def run(**kwargs):
    """
    Run visualize protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which plot to generate, must be a key of PROTOCOLS

        For plotly:
            projected_reduced_embeddings_file: The projected (dimensionality reduced) embeddings,
                                               normally coming from the project stage.
                                               The legacy name projected_embeddings_file is also accepted
            annotation_file*: csv file with annotations
            display_unknown*: Whether to show proteins for which there is no annotation in the
                              annotation file (only relevant if annotation file is provided)
            merge_via_index*: Set to True if in annotation_file identifiers correspond to sequence MD5 hashes
            n_components*: 2D vs 3D plot

        For plot_mutagenesis:
            residue_probabilities_file: The csv with the probabilities, normally coming from the
                                        mutagenesis stage

    Returns
    -------
    Dictionary with results of stage

    Raises
    ------
    InvalidParameterError
        If the protocol is unknown, or if the plotly protocol is selected and
        neither projected_reduced_embeddings_file nor projected_embeddings_file is given
    """
    check_required(kwargs, ['protocol', 'prefix', 'stage_name'])

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())
            )
        )

    result_kwargs = deepcopy(kwargs)

    if kwargs["protocol"] == "plotly":
        # Support legacy projected_embeddings_file
        projected_reduced_embeddings_file = (
            kwargs.get("projected_reduced_embeddings_file")
            or kwargs.get("projected_embeddings_file")
        )
        if not projected_reduced_embeddings_file:
            # Only name the two keys that are actually read above
            raise InvalidParameterError(
                f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file "
                f"for {kwargs['protocol']}"
            )
        result_kwargs["projected_reduced_embeddings_file"] = projected_reduced_embeddings_file

    return PROTOCOLS[kwargs["protocol"]](result_kwargs)
def run(**kwargs):
    """
    Run project protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_reduced_embeddings_file or projected_embeddings_file or reduced_embeddings_file:
            The input embeddings; the first of the three (in this order) with a truthy value is used
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which projection technique to use, must be a key of PROTOCOLS
        mapping_file: the mapping file generated by the pipeline when remapping indexes

    Returns
    -------
    Dictionary with results of stage
    """
    check_required(kwargs, ["protocol", "prefix", "stage_name", "mapping_file"])

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        valid_protocols = ", ".join(PROTOCOLS.keys())
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(protocol, valid_protocols)
        )

    result_kwargs = deepcopy(kwargs)

    # We want to allow chaining protocols, e.g. first tucker than umap,
    # so we need to allow projected embeddings as input
    accepted_input_keys = (
        "projected_reduced_embeddings_file",
        "projected_embeddings_file",
        "reduced_embeddings_file",
    )
    embeddings_input_file = None
    for input_key in accepted_input_keys:
        embeddings_input_file = kwargs.get(input_key)
        if embeddings_input_file:
            break

    if not embeddings_input_file:
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file or "
            f"reduced_embeddings_file for {kwargs['protocol']}"
        )

    # The protocol implementations read their input from this key
    result_kwargs["reduced_embeddings_file"] = embeddings_input_file

    file_manager = get_file_manager(**kwargs)
    protocol_function = PROTOCOLS[protocol]
    return protocol_function(file_manager, result_kwargs)
def run(**kwargs):
    """
    Run project protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        reduced_embeddings_file: Where per-protein embeddings live
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which projection technique to use, must be a key of PROTOCOLS
        mapping_file: the mapping file generated by the pipeline when remapping indexes

    Returns
    -------
    Dictionary with results of stage
    """
    required_parameters = [
        'protocol',
        'prefix',
        'stage_name',
        'reduced_embeddings_file',
        'mapping_file',
    ]
    check_required(kwargs, required_parameters)

    protocol_name = kwargs["protocol"]
    if protocol_name in PROTOCOLS:
        # All parameters are handed to the protocol implementation unchanged
        return PROTOCOLS[protocol_name](**kwargs)

    valid_protocols = ", ".join(PROTOCOLS.keys())
    raise InvalidParameterError(
        "Invalid protocol selection: "
        + "{}. Valid protocols are: {}".format(protocol_name, valid_protocols)
    )
def run(**kwargs):
    """
    Run visualize protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_embeddings_file: A csv with columns: (index), original_id, x, y, z
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which plot to generate, must be a key of PROTOCOLS

    Returns
    -------
    Dictionary with results of stage
    """
    required_parameters = [
        'protocol',
        'prefix',
        'stage_name',
        'projected_embeddings_file',
    ]
    check_required(kwargs, required_parameters)

    chosen_protocol = kwargs["protocol"]
    if chosen_protocol in PROTOCOLS:
        # All parameters are handed to the protocol implementation unchanged
        return PROTOCOLS[chosen_protocol](**kwargs)

    available_protocols = ", ".join(PROTOCOLS.keys())
    raise InvalidParameterError(
        "Invalid protocol selection: "
        + "{}. Valid protocols are: {}".format(chosen_protocol, available_protocols)
    )
def _get_embeddings_file_context(
    file_manager: FileManagerInterface, result_kwargs: Dict[str, Any]
):
    """
    Create the context in which the per-amino-acid embeddings are written.

    :param file_manager: The FileManager derived class which will be used to create the file
    :param result_kwargs: A dictionary which will be updated in-place: `discard_per_amino_acid_embeddings`
        is defaulted to False, and `embeddings_file` is set to the path of the newly created file
        (unless the embeddings are discarded)
    :return: an h5py file opened for writing, or a nullcontext if
        `discard_per_amino_acid_embeddings` is True
    :raises InvalidParameterError: if the embeddings are discarded while neither `reduce` nor
        `embeddings_transformer_function` is set
    """
    discard_embeddings = result_kwargs.setdefault(
        "discard_per_amino_acid_embeddings", False
    )

    # Only the literal True discards the embeddings, everything else writes the file
    if discard_embeddings is not True:
        embeddings_file_path = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "embeddings_file",
            extension=".h5",
        )
        result_kwargs["embeddings_file"] = embeddings_file_path
        return h5py.File(embeddings_file_path, "w")

    reduce_is_off = result_kwargs.get("reduce", False) is False
    no_transformer = result_kwargs.get("embeddings_transformer_function") is None
    if reduce_is_off and no_transformer:
        raise InvalidParameterError(
            "Cannot only have discard_per_amino_acid_embeddings: True. "
            "Either also set `reduce: True` or define an `embeddings_transformer_function`, or both."
        )

    return nullcontext()
def _validate_file(file_path: str):
    """
    Ensure that a file exists and is not empty.

    Parameters
    ----------
    file_path : str
        Path to file to check

    Returns
    -------
    None
        Nothing is returned for a valid file; problems are reported by raising.

    Raises
    ------
    InvalidParameterError
        If the file has a size of zero, or if `os.stat` raises an OSError or
        TypeError for the given path.
    """
    try:
        size_in_bytes = os.stat(file_path).st_size
        if not size_in_bytes:
            raise InvalidParameterError(f"The file at '{file_path}' is empty")
    except (OSError, TypeError) as stat_error:
        raise InvalidParameterError(
            f"The configuration file at '{file_path}' does not exist"
        ) from stat_error
def run(**kwargs):
    """
    Run visualize protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_reduced_embeddings_file: A csv with columns: (index), original_id, x, y, z.
                                           The legacy name projected_embeddings_file is also accepted
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which plot to generate, must be a key of PROTOCOLS

    Returns
    -------
    Dictionary with results of stage

    Raises
    ------
    InvalidParameterError
        If the protocol is unknown, or if neither projected_reduced_embeddings_file
        nor projected_embeddings_file is given
    """
    check_required(kwargs, ['protocol', 'prefix', 'stage_name'])

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())
            )
        )

    result_kwargs = deepcopy(kwargs)

    # Support legacy projected_embeddings_file
    projected_reduced_embeddings_file = (
        kwargs.get("projected_reduced_embeddings_file")
        or kwargs.get("projected_embeddings_file")
    )

    if not projected_reduced_embeddings_file:
        # Only name the two keys that are actually read above
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file "
            f"for {kwargs['protocol']}"
        )

    result_kwargs["projected_reduced_embeddings_file"] = projected_reduced_embeddings_file

    return PROTOCOLS[kwargs["protocol"]](result_kwargs)
def prepare_kwargs(**kwargs):
    """
    Validate the parameters of the embed stage and select the embedder class.

    Unknown options only produce a warning, they are not an error.

    :param kwargs: The stage parameters. `protocol`, `prefix`, `stage_name`,
        `remapped_sequences_file` and `mapping_file` are required
    :return: A tuple of the embedder class for the protocol and a deep copy of
        kwargs in which `max_amino_acids` is defaulted from DEFAULT_MAX_AMINO_ACIDS
    :raises InvalidParameterError: if the protocol is unknown or its extra is not
        installed, if `use_cpu` is configured for unirep, if `half_precision_model`
        is set for an unsupported protocol, or if `half_precision_model` is set
        while `half_precision` is explicitly False
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)

    protocol = kwargs["protocol"]

    if protocol not in name_to_embedder:
        # A known protocol that is not available means its dependencies are missing
        if protocol in ALL_PROTOCOLS:
            raise InvalidParameterError(
                f"The extra for the protocol {protocol} is missing. "
                "See https://docs.bioembeddings.com/#installation on how to install all extras"
            )
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                protocol, ", ".join(name_to_embedder.keys())
            )
        )

    embedder_class = name_to_embedder[protocol]

    if protocol == "unirep" and kwargs.get("use_cpu") is not None:
        raise InvalidParameterError("UniRep does not support configuring `use_cpu`")

    # See parameter_blueprints.yml
    global_options = {"sequences_file", "simple_remapping", "start_time"}
    embed_options = {
        "decoder",
        "device",
        "discard_per_amino_acid_embeddings",
        "half_precision_model",
        "half_precision",
        "max_amino_acids",
        "reduce",
        "type",
    }
    known_parameters = (
        set(required_kwargs)
        | global_options
        | embed_options
        | set(embedder_class.necessary_files)
        | set(embedder_class.necessary_directories)
    )
    # Compare the protocol name, not the embedder class: a class never equals
    # the string "seqvec", so `model_directory` was always reported as unknown
    if protocol == "seqvec":
        # We support two ways of configuration for seqvec
        known_parameters.add("model_directory")

    # Complain louder if the input looks fishier.
    # Sorted so that the warnings come in a stable order
    for option in sorted(set(kwargs) - known_parameters):
        logger.warning(
            f"You set an unknown option for {embedder_class.name}: {option} (value: {kwargs[option]})"
        )

    if kwargs.get("half_precision_model"):
        if protocol not in ("prottrans_t5_bfd", "prottrans_t5_uniref50"):
            raise InvalidParameterError(
                "`half_precision_model` is only supported with prottrans_t5_bfd and prottrans_t5_uniref50"
            )

        if kwargs.get("half_precision") is False:  # None remains allowed
            raise InvalidParameterError(
                "You can't have `half_precision_model` be true and `half_precision` be false. "
                "We suggest also setting `half_precision` to true, "
                "which will compute and save embeddings as half-precision floats"
            )

    result_kwargs = deepcopy(kwargs)
    result_kwargs.setdefault(
        "max_amino_acids", DEFAULT_MAX_AMINO_ACIDS[protocol]
    )

    return embedder_class, result_kwargs
def plotly(result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Render the projected embeddings as an interactive 2D or 3D scatter plot.

    The projected embeddings are read from `projected_reduced_embeddings_file`
    (.csv or .h5), optionally joined with the annotations from `annotation_file`,
    and written out as a merged annotation csv and an html plot.

    :param result_kwargs: The stage parameters, updated in-place. Read are
        `projected_reduced_embeddings_file`, `prefix`, `stage_name` and the optional
        `n_components` (default 3), `annotation_file`, `merge_via_index` (default False)
        and `display_unknown` (default True). Added are `merged_annotation_file` and
        `plot_file` as well as the defaults of the optional parameters that were used
    :return: The updated result_kwargs
    :raises TooFewComponentsException: if `n_components` is smaller than 2
    :raises RuntimeError: if an embedding in the .h5 file does not have shape (3,)
    :raises InvalidParameterError: if the projected embeddings file is neither .csv nor .h5
    """
    file_manager = get_file_manager(**result_kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = result_kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >=3, will render 3D scatter plot.")

    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    projected_file = result_kwargs["projected_reduced_embeddings_file"]
    suffix = Path(projected_file).suffix
    if suffix == ".csv":
        # Support the legacy csv format
        merged_annotation_file = read_csv(projected_file, index_col=0)
    elif suffix == ".h5":
        # convert h5 to dataframe with ids and one column per dimension
        rows = []
        with h5py.File(projected_file, "r") as file:
            for sequence_id, embedding in file.items():
                if embedding.shape != (3,):
                    raise RuntimeError(
                        f"Expected embeddings in projected_reduced_embeddings_file "
                        f"to be of shape (3,), not {embedding.shape}"
                    )
                rows.append(
                    (
                        sequence_id,
                        embedding.attrs["original_id"],
                        embedding[0],
                        embedding[1],
                        embedding[2],
                    )
                )
        columns = [
            "sequence_id",
            "original_id",
            "component_0",
            "component_1",
            "component_2",
        ]
        merged_annotation_file = DataFrame.from_records(
            rows, index="sequence_id", columns=columns
        )
    else:
        raise InvalidParameterError(
            f"Expected .csv or .h5 as suffix for projected_reduced_embeddings_file, got {suffix}"
        )

    if result_kwargs.get('annotation_file'):
        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')
        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        if result_kwargs['merge_via_index']:
            projected_embeddings = merged_annotation_file
        else:
            projected_embeddings = merged_annotation_file.set_index('original_id')

        if result_kwargs['display_unknown']:
            # The outer join keeps proteins without annotation, they get a placeholder label
            merged_annotation_file = annotation_file.join(projected_embeddings, how="outer")
            # Assign the filled column back instead of a chained `fillna(inplace=True)`,
            # which pandas deprecates and which has no effect with Copy-on-Write
            merged_annotation_file['label'] = merged_annotation_file['label'].fillna('UNKNOWN')
        else:
            merged_annotation_file = annotation_file.join(projected_embeddings)
    else:
        merged_annotation_file['label'] = 'UNKNOWN'

    merged_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                           result_kwargs.get('stage_name'),
                                                           'merged_annotation_file',
                                                           extension='.csv')
    merged_annotation_file.to_csv(merged_annotation_file_path)
    result_kwargs['merged_annotation_file'] = merged_annotation_file_path

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')
    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs