def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        remapped_sequences_file: FASTA file with the remapped sequences to embed
        prefix: Output prefix for all generated files
        protocol: Which embedder to use (a key of PROTOCOLS)
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        *use_cpu: Forwarded to the embedder; rejected for UniRep
        *max_amino_acids: Defaults to DEFAULT_MAX_AMINO_ACIDS[protocol]
        *half_precision: Forwarded to embed_and_write_batched (default False)
        *<model files/directories>: Any of the embedder's necessary files or
            directories; those not given are fetched via get_model_file /
            get_model_directories_from_zip

    Returns
    -------
    Dictionary with results of stage

    Raises
    ------
    InvalidParameterError
        If `protocol` is unknown, or `use_cpu` is set for the UniRep embedder.
    """
    check_required(
        kwargs,
        [
            "protocol",
            "prefix",
            "stage_name",
            "remapped_sequences_file",
            "mapping_file"
        ],
    )
    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())))

    embedder_class = PROTOCOLS[kwargs["protocol"]]

    if embedder_class == UniRepEmbedder and kwargs.get("use_cpu") is not None:
        raise InvalidParameterError(
            "UniRep does not support configuring `use_cpu`")

    result_kwargs = deepcopy(kwargs)

    # Download necessary files if needed
    # noinspection PyProtectedMember
    for file in embedder_class._necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=embedder_class.name, file=file)

    # noinspection PyProtectedMember
    for directory in embedder_class._necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(
                model=embedder_class.name, directory=directory)

    result_kwargs.setdefault("max_amino_acids",
                             DEFAULT_MAX_AMINO_ACIDS[kwargs["protocol"]])

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
def run(**kwargs):
    """
    Run the embedding protocol for one pipeline stage.

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        *half_precision: Forwarded to embed_and_write_batched (default False)

    Validation and defaulting of the arguments is delegated to prepare_kwargs,
    which also selects the embedder class.

    Returns
    -------
    Dictionary with results of stage
    """
    embedder_cls, stage_kwargs = prepare_kwargs(**kwargs)
    manager = get_file_manager(**kwargs)

    embedder: EmbedderInterface = embedder_cls(**stage_kwargs)
    _check_transform_embeddings_function(embedder, stage_kwargs)

    use_half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(embedder, manager, stage_kwargs, use_half_precision)
def _process_fasta_file(**kwargs):
    """
    Copy the input FASTA into the prefix and remap its sequence identifiers.

    Identifiers are remapped to MD5 hashes of the sequences, or to sequence
    positions when `simple_remapping` is set (discouraged).

    Returns a copy of kwargs with `sequences_file`, `simple_remapping`,
    `mapping_file` and `remapped_sequences_file` set.

    Raises MD5ClashException if the remapped index contains duplicates.
    """
    output_kwargs = deepcopy(kwargs)
    manager = get_file_manager(**kwargs)
    prefix = kwargs.get('prefix')

    entries = read_fasta(kwargs['sequences_file'])

    # Store a copy of the parsed input sequences inside the prefix
    copied_fasta_path = manager.create_file(prefix, None, 'sequences_file', extension='.fasta')
    write_fasta_file(entries, copied_fasta_path)
    output_kwargs['sequences_file'] = copied_fasta_path

    # Positional remapping is opt-in (and not encouraged); MD5 remapping is the default
    use_simple = output_kwargs.get('simple_remapping', False)
    output_kwargs['simple_remapping'] = use_simple
    mapping = reindex_sequences(entries, simple=use_simple)

    # A repeated index almost always means identical sequences; refuse to embed them twice
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_path = manager.create_file(prefix, None, 'mapping_file', extension='.csv')
    remapped_fasta_path = manager.create_file(prefix, None, 'remapped_sequences_file', extension='.fasta')

    # `entries` is written a second time after reindex_sequences; presumably that
    # call renamed the records in place -- NOTE(review): confirm
    write_fasta_file(entries, remapped_fasta_path)
    mapping.to_csv(mapping_path)

    output_kwargs['mapping_file'] = mapping_path
    output_kwargs['remapped_sequences_file'] = remapped_fasta_path
    return output_kwargs
def run(**kwargs):
    """
    Run project protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_reduced_embeddings_file or projected_embeddings_file or reduced_embeddings_file:
            Where per-protein embeddings live; the first non-empty one (in this order) is used
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which projection technique to use (a key of PROTOCOLS)
        mapping_file: the mapping file generated by the pipeline when remapping indexes

    Returns
    -------
    Dictionary with results of stage, as returned by the selected protocol

    Raises
    ------
    InvalidParameterError
        For an unknown protocol or when no embeddings input file is given.
    """
    check_required(kwargs, ["protocol", "prefix", "stage_name", "mapping_file"])

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(protocol, ", ".join(PROTOCOLS.keys()))
        )

    result_kwargs = deepcopy(kwargs)

    # Projections can be chained (e.g. tucker followed by umap), so an already
    # projected file is an acceptable input too.
    candidate_keys = (
        "projected_reduced_embeddings_file",
        "projected_embeddings_file",
        "reduced_embeddings_file",
    )
    input_file = next((kwargs[key] for key in candidate_keys if kwargs.get(key)), None)

    if not input_file:
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file or "
            f"reduced_embeddings_file for {protocol}"
        )

    # Protocols always read their input from `reduced_embeddings_file`
    result_kwargs["reduced_embeddings_file"] = input_file

    manager = get_file_manager(**kwargs)
    projector = PROTOCOLS[protocol]
    return projector(manager, result_kwargs)
def tsne(**kwargs):
    """
    Project per-protein embeddings with t-SNE and write the coordinates to CSV.

    For every id in the index of `mapping_file`, the embedding is read from the
    HDF5 file `reduced_embeddings_file`. Missing t-SNE parameters are filled
    with defaults, and one `component_<i>` column per output dimension is added
    to the mapping, which is written as `projected_embeddings_file`.

    Returns
    -------
    A copy of kwargs with the effective t-SNE parameters and
    `projected_embeddings_file` added.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Get sequence mapping to use as information source
    mapping = read_csv(result_kwargs['mapping_file'], index_col=0)

    with h5py.File(result_kwargs['reduced_embeddings_file'], 'r') as f:
        reduced_embeddings = [np.array(f[str(remapped_id)]) for remapped_id in mapping.index]

    # Get parameters or set defaults
    result_kwargs['perplexity'] = kwargs.get('perplexity', 6)
    result_kwargs['n_jobs'] = kwargs.get('n_jobs', -1)
    result_kwargs['n_iter'] = kwargs.get('n_iter', 15000)
    result_kwargs['metric'] = kwargs.get('metric', 'cosine')
    result_kwargs['n_components'] = kwargs.get('n_components', 3)
    result_kwargs['random_state'] = kwargs.get('random_state', 420)
    result_kwargs['verbose'] = kwargs.get('verbose', 1)

    # Pass result_kwargs (not kwargs) so the defaults above are really used.
    # Otherwise tsne_reduce falls back to its own defaults, which need not match
    # the n_components used to read the output columns below.
    projected_embeddings = tsne_reduce(reduced_embeddings, **result_kwargs)

    for i in range(result_kwargs['n_components']):
        mapping[f'component_{i}'] = projected_embeddings[:, i]

    projected_embeddings_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'projected_embeddings_file',
                                                              extension='.csv')
    mapping.to_csv(projected_embeddings_file_path)
    result_kwargs['projected_embeddings_file'] = projected_embeddings_file_path

    return result_kwargs
def run(**kwargs):
    """
    Run the embedding protocol for one pipeline stage.

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name
        *half_precision: Forwarded to embed_and_write_batched (default False)

    Any of the embedder's necessary files/directories that are not supplied are
    fetched via get_model_file / get_model_directories_from_zip.

    Returns
    -------
    Dictionary with results of stage
    """
    embedder_class, result_kwargs = prepare_kwargs(**kwargs)

    # Fill in model files the caller did not provide
    for file_key in embedder_class.necessary_files:
        if result_kwargs.get(file_key):
            continue
        result_kwargs[file_key] = get_model_file(model=embedder_class.name, file=file_key)

    # Same for model directories (shipped as zip archives)
    for directory_key in embedder_class.necessary_directories:
        if result_kwargs.get(directory_key):
            continue
        result_kwargs[directory_key] = get_model_directories_from_zip(
            model=embedder_class.name, directory=directory_key)

    manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(embedder, manager, result_kwargs, half_precision)
def plot_mutagenesis(result_kwargs):
    """BETA: visualize in-silico mutagenesis as a heatmap with plotly

    Writes one HTML file per sequence id found in the residue probabilities file.

    mandatory:
    * protocol
    * prefix
    * stage_name
    * residue_probabilities_file

    Returns result_kwargs unchanged.

    Raises ValueError if the columns of the probabilities file are not exactly
    PROBABILITIES_COLUMNS.
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "residue_probabilities_file",
    ]
    check_required(result_kwargs, required_kwargs)

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    probabilities_all = pandas.read_csv(result_kwargs["residue_probabilities_file"])
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # and this validates user-provided input.
    if list(probabilities_all.columns) != PROBABILITIES_COLUMNS:
        raise ValueError(
            f"probabilities file is expected to have the following columns: {PROBABILITIES_COLUMNS}"
        )

    number_of_proteins = len(set(probabilities_all["id"]))
    for sequence_id, probabilities in tqdm(
        probabilities_all.groupby("id"), total=number_of_proteins
    ):
        fig = plot(probabilities)
        plotly.offline.plot(
            fig,
            filename=file_manager.create_file(
                result_kwargs.get("prefix"),
                result_kwargs.get("stage_name"),
                sequence_id,
                extension=".html",
            ),
        )

    return result_kwargs
def _process_fasta_file(**kwargs):
    """
    Validate the input FASTA, copy it into the prefix and remap its identifiers.

    Every sequence may contain only ASCII letters (ValueError otherwise); lower
    case letters are allowed but produce a warning. Identifiers are remapped to
    MD5 hashes of the sequences, or to sequence positions when
    `simple_remapping` is set (discouraged).

    Returns a copy of kwargs with `sequences_file`, `simple_remapping`,
    `mapping_file` and `remapped_sequences_file` set.

    Raises MD5ClashException if the remapped index contains duplicates.
    """
    output_kwargs = deepcopy(kwargs)
    manager = get_file_manager(**kwargs)
    prefix = kwargs.get('prefix')

    records = read_fasta(kwargs['sequences_file'])

    # Reject anything that is not single letter code before it reaches an embedder
    allowed = frozenset(string.ascii_letters)
    for record in records:
        bad_characters = sorted(set(record.seq) - allowed)
        if bad_characters:
            formatted = ", ".join(f"'{character}'" for character in bad_characters)
            raise ValueError(
                f"The entry '{record.name}' in {kwargs['sequences_file']} contains the characters {formatted}, "
                f"while only single letter code is allowed "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )
        # Only a warning: embedders handle lower case inconsistently
        if not str(record.seq).isupper():
            logger.warning(
                f"The entry '{record.name}' in {kwargs['sequences_file']} contains lower case amino acids. "
                f"Lower case letters are uninterpretable by most language models, "
                f"and their embedding will be nonesensical. "
                f"Protein LMs available through bio_embeddings have been trained on upper case, "
                f"single letter code sequence representations only "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )

    copied_fasta_path = manager.create_file(prefix, None, 'sequences_file', extension='.fasta')
    write_fasta_file(records, copied_fasta_path)
    output_kwargs['sequences_file'] = copied_fasta_path

    # Positional remapping is opt-in (and not encouraged); MD5 remapping is the default
    use_simple = output_kwargs.get('simple_remapping', False)
    output_kwargs['simple_remapping'] = use_simple
    mapping = reindex_sequences(records, simple=use_simple)

    # A repeated index almost always means identical sequences; refuse to embed them twice
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_path = manager.create_file(prefix, None, 'mapping_file', extension='.csv')
    remapped_fasta_path = manager.create_file(prefix, None, 'remapped_sequences_file', extension='.fasta')

    # `records` is written again after reindex_sequences; presumably that call
    # renamed the records in place -- NOTE(review): confirm
    write_fasta_file(records, remapped_fasta_path)
    mapping.to_csv(mapping_path)

    output_kwargs['mapping_file'] = mapping_path
    output_kwargs['remapped_sequences_file'] = remapped_fasta_path
    return output_kwargs
def execute_pipeline_from_config(config: Dict,
                                 post_stage: Callable[[Dict], None] = _null_function,
                                 **kwargs) -> Dict:
    """
    Run every stage described in `config`, writing each stage's input and output config.

    `config` must contain a "global" section with at least "prefix" and
    "sequences_file"; every other top-level key is a stage name whose section
    needs "protocol" and "type". Stages run in the iteration order of `config`.
    A stage may name another stage in "depends_on"; that stage's entry in
    `config` (its output parameters if it already ran) is merged between the
    global parameters and the stage's own parameters.

    Parameters
    ----------
    config: The pipeline configuration. It is modified in place: the "global"
        section is popped and later re-added with the processed values, and each
        stage entry is replaced by that stage's output parameters.
    post_stage: Called with each stage's output parameters after it finishes.
    kwargs: `overwrite` allows running into an already existing prefix.

    Returns
    -------
    The mutated config holding the outputs of all stages.

    Raises
    ------
    FileExistsError: if the prefix exists and `overwrite` is not set.
    Exception: for an invalid stage type or a missing dependency.
    If a stage itself raises, the traceback and an issue-report URL are printed
    to stderr and the process exits with status 1.
    """
    original_config = deepcopy(config)

    check_required(config, ["global"])

    # !! pop = remove from config!
    global_parameters = config.pop('global')

    check_required(global_parameters, ["prefix", "sequences_file"])

    file_manager = get_file_manager(**global_parameters)

    # Make sure prefix exists
    prefix = global_parameters['prefix']

    # If prefix already exists
    if file_manager.exists(prefix):
        if not kwargs.get('overwrite'):
            raise FileExistsError(
                "The prefix already exists & no overwrite option has been set.\n"
                "Either set --overwrite, or move data from the prefix.\n"
                "Prefix: {}".format(prefix))
    else:
        # create the prefix
        file_manager.create_prefix(prefix)

    # Copy original config to prefix
    global_in = file_manager.create_file(prefix, None, _IN_CONFIG_NAME, extension='.yml')
    write_config_file(global_in, original_config)

    # This downloads sequences_file if required. Use the returned parameters, as
    # is done for the stages below, instead of relying on in-place mutation.
    global_parameters = download_files_for_stage(global_parameters, file_manager, prefix)

    global_parameters = _process_fasta_file(**global_parameters)

    for stage_name in config:
        stage_parameters = config[stage_name]
        # Snapshot of the user-provided stage config, used for error reports.
        # dict(...) rather than dict(**...) also tolerates non-string keys.
        original_stage_parameters = dict(stage_parameters)

        check_required(stage_parameters, ["protocol", "type"])

        stage_type = stage_parameters['type']
        stage_runnable = _STAGES.get(stage_type)

        if not stage_runnable:
            raise Exception(
                "No type defined, or invalid stage type defined: {}".format(
                    stage_type))

        # Prepare to run stage
        stage_parameters['stage_name'] = stage_name
        file_manager.create_stage(prefix, stage_name)

        stage_parameters = download_files_for_stage(stage_parameters, file_manager, prefix, stage_name)

        stage_dependency = stage_parameters.get('depends_on')

        if stage_dependency:
            if stage_dependency not in config:
                raise Exception(
                    "Stage {} depends on {}, but dependency not found in config."
                    .format(stage_name, stage_dependency))

            stage_dependency_parameters = config.get(stage_dependency)
            # Later sources win: global < dependency < this stage
            stage_parameters = {
                **global_parameters,
                **stage_dependency_parameters,
                **stage_parameters
            }
        else:
            stage_parameters = {**global_parameters, **stage_parameters}

        # Register start time
        start_time = datetime.now().astimezone()
        stage_parameters['start_time'] = str(start_time)

        stage_in = file_manager.create_file(prefix, stage_name, _IN_CONFIG_NAME, extension='.yml')
        write_config_file(stage_in, stage_parameters)

        try:
            stage_output_parameters = stage_runnable(**stage_parameters)
        except Exception as e:
            # Tell the user which stage failed and show an url to report an error on github
            try:
                version = importlib_metadata.version("bio_embeddings")
            except PackageNotFoundError:
                version = "unknown"

            # Make a github flavored markdown table; the header is in the template
            parameter_table = "\n".join(
                f"{key}|{value}" for key, value in original_stage_parameters.items())
            params = {
                # https://stackoverflow.com/a/35498685/3549270
                "title": f"Protocol {original_stage_parameters['protocol']}: {type(e).__name__}: {e}",
                "body": _ERROR_REPORTING_TEMPLATE.format(
                    version,
                    torch.cuda.is_available(),
                    parameter_table,
                    traceback.format_exc(10),
                ),
            }
            print(traceback.format_exc(), file=sys.stderr)
            print(
                f"Consider reporting this error at this url: {_ISSUE_URL}?{urllib.parse.urlencode(params)}\n\n"
                f"Stage {stage_name} failed.",
                file=sys.stderr,
            )

            sys.exit(1)

        # Register end time
        end_time = datetime.now().astimezone()
        stage_output_parameters['end_time'] = str(end_time)

        # Register elapsed time
        stage_output_parameters['elapsed_time'] = str(end_time - start_time)

        stage_out = file_manager.create_file(prefix, stage_name, _OUT_CONFIG_NAME, extension='.yml')
        write_config_file(stage_out, stage_output_parameters)

        # Store in global_out config for later retrieval (e.g. depends_on)
        config[stage_name] = stage_output_parameters

        # Execute post-stage function, if provided
        post_stage(stage_output_parameters)

    config['global'] = global_parameters

    try:
        config['global']['version'] = importlib_metadata.version(
            "bio_embeddings")
    except PackageNotFoundError:
        pass  # :(

    global_out = file_manager.create_file(prefix, None, _OUT_CONFIG_NAME, extension='.yml')
    write_config_file(global_out, config)

    return config
def plotly(**kwargs):
    """
    Render projected embeddings as an interactive 2D or 3D plotly scatter plot.

    Reads `projected_embeddings_file` (CSV, first column used as index, with an
    `original_id` column). If `annotation_file` is given (CSV with an
    `identifier` column, presumably also providing `label` -- confirm), it is
    joined with the coordinates, either via the index (`merge_via_index`) or via
    `original_id`. Without an annotation file, every point is labelled 'UNKNOWN'.

    Optional kwargs: n_components (default 3; 2 renders 2D, anything larger 3D),
    merge_via_index (default False), display_unknown (default True).

    Returns a copy of kwargs with `merged_annotation_file` and `plot_file` added.

    Raises TooFewComponentsException if n_components < 2.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >2, will render 3D scatter plot.")

    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    projected_embeddings_file = read_csv(result_kwargs['projected_embeddings_file'], index_col=0)

    if result_kwargs.get('annotation_file'):
        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')
        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        if result_kwargs['merge_via_index']:
            projections = projected_embeddings_file
        else:
            projections = projected_embeddings_file.set_index('original_id')

        if result_kwargs['display_unknown']:
            merged_annotation_file = annotation_file.join(projections, how="outer")
            # Assign the result instead of `fillna(..., inplace=True)` on a column
            # selection: that is chained assignment and does not update the frame
            # under pandas copy-on-write.
            merged_annotation_file['label'] = merged_annotation_file['label'].fillna('UNKNOWN')
        else:
            merged_annotation_file = annotation_file.join(projections)
    else:
        # Nothing to join: plot all points with an unknown label. Without this
        # branch merged_annotation_file would be undefined below.
        merged_annotation_file = projected_embeddings_file
        merged_annotation_file['label'] = 'UNKNOWN'

    merged_annotation_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                           result_kwargs.get('stage_name'),
                                                           'merged_annotation_file',
                                                           extension='.csv')

    merged_annotation_file.to_csv(merged_annotation_file_path)
    result_kwargs['merged_annotation_file'] = merged_annotation_file_path

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')

    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs
def plotly(result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Render projected embeddings as an interactive 2D or 3D plotly scatter plot.

    `projected_reduced_embeddings_file` is either a legacy CSV (first column used
    as index) or an HDF5 file whose datasets are vectors of shape (3,) with an
    `original_id` attribute. If `annotation_file` is given (CSV with an
    `identifier` column, presumably also providing `label` -- confirm), it is
    joined with the coordinates via the index (`merge_via_index`) or via
    `original_id`. Without an annotation file, every point is labelled 'UNKNOWN'.

    Optional: n_components (default 3; 2 renders 2D, anything larger 3D),
    merge_via_index (default False), display_unknown (default True).

    `result_kwargs` is updated in place and returned, with
    `merged_annotation_file` and `plot_file` added.

    Raises TooFewComponentsException if n_components < 2, RuntimeError for an
    HDF5 dataset not of shape (3,), and InvalidParameterError for a suffix other
    than .csv or .h5.
    """
    file_manager = get_file_manager(**result_kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = result_kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >2, will render 3D scatter plot.")

    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    suffix = Path(result_kwargs["projected_reduced_embeddings_file"]).suffix
    if suffix == ".csv":
        # Support the legacy csv format
        merged_annotation_file = read_csv(
            result_kwargs["projected_reduced_embeddings_file"], index_col=0
        )
    elif suffix == ".h5":
        # convert h5 to dataframe with ids and one column per dimension
        rows = []
        with h5py.File(result_kwargs["projected_reduced_embeddings_file"], "r") as h5_file:
            for sequence_id, embedding in h5_file.items():
                if embedding.shape != (3,):
                    raise RuntimeError(
                        f"Expected embeddings in projected_reduced_embeddings_file "
                        f"to be of shape (3,), not {embedding.shape}"
                    )
                row = (
                    sequence_id,
                    embedding.attrs["original_id"],
                    embedding[0],
                    embedding[1],
                    embedding[2],
                )
                rows.append(row)
        columns = [
            "sequence_id",
            "original_id",
            "component_0",
            "component_1",
            "component_2",
        ]
        merged_annotation_file = DataFrame.from_records(
            rows, index="sequence_id", columns=columns
        )
    else:
        raise InvalidParameterError(
            f"Expected .csv or .h5 as suffix for projected_reduced_embeddings_file, got {suffix}"
        )

    if result_kwargs.get('annotation_file'):
        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')
        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        if result_kwargs['merge_via_index']:
            projections = merged_annotation_file
        else:
            projections = merged_annotation_file.set_index('original_id')

        if result_kwargs['display_unknown']:
            merged_annotation_file = annotation_file.join(projections, how="outer")
            # Assign the result instead of `fillna(..., inplace=True)` on a column
            # selection: that is chained assignment and does not update the frame
            # under pandas copy-on-write.
            merged_annotation_file['label'] = merged_annotation_file['label'].fillna('UNKNOWN')
        else:
            merged_annotation_file = annotation_file.join(projections)
    else:
        merged_annotation_file['label'] = 'UNKNOWN'

    merged_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                           result_kwargs.get('stage_name'),
                                                           'merged_annotation_file',
                                                           extension='.csv')

    merged_annotation_file.to_csv(merged_annotation_file_path)
    result_kwargs['merged_annotation_file'] = merged_annotation_file_path

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')

    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs
def execute_pipeline_from_config(config: Dict,
                                 post_stage: Callable[[Dict], None] = _null_function,
                                 **kwargs) -> Dict:
    """
    Run every stage described in `config`, writing each stage's input and output config.

    `config` needs a "global" section with at least "prefix" and
    "sequences_file"; every other top-level key is a stage name whose section
    needs "protocol" and "type". Stages run in the iteration order of `config`.
    A stage may name another stage in "depends_on"; that stage's entry in
    `config` (its output parameters if it already ran) is merged between the
    global parameters and the stage's own parameters.

    `config` is modified in place (the "global" section is popped and re-added,
    stage entries are replaced by their outputs) and returned. `post_stage` is
    called with each stage's output parameters. Pass `overwrite=True` to reuse
    an existing prefix; otherwise FileExistsError is raised. Exceptions raised
    by a stage propagate to the caller.
    """
    config_snapshot = deepcopy(config)

    check_required(config, ["global"])
    # NB: popping removes the global section, so only stages remain in `config`
    global_parameters = config.pop('global')
    check_required(global_parameters, ["prefix", "sequences_file"])

    file_manager = get_file_manager(**global_parameters)
    prefix = global_parameters['prefix']

    prefix_exists = file_manager.exists(prefix)
    if prefix_exists and not kwargs.get('overwrite'):
        raise FileExistsError(
            "The prefix already exists & no overwrite option has been set.\n"
            "Either set --overwrite, or move data from the prefix.\n"
            "Prefix: {}".format(prefix))
    if not prefix_exists:
        file_manager.create_prefix(prefix)

    # Record the installed bio_embeddings version, if package metadata is available
    try:
        Path(prefix).joinpath("bio_embeddings_version.txt").write_text(
            importlib_metadata.version("bio_embeddings"))
    except PackageNotFoundError:
        pass  # :(

    # Keep the config exactly as the user provided it
    write_config_file(
        file_manager.create_file(prefix, None, _IN_CONFIG_NAME, extension='.yml'),
        config_snapshot)

    global_parameters = _process_fasta_file(**global_parameters)

    for stage_name in config:
        stage_section = config[stage_name]
        check_required(stage_section, ["protocol", "type"])

        stage_type = stage_section['type']
        stage_runnable = _STAGES.get(stage_type)
        if not stage_runnable:
            raise Exception(
                "No type defined, or invalid stage type defined: {}".format(stage_type))

        stage_section['stage_name'] = stage_name
        file_manager.create_stage(prefix, stage_name)

        dependency = stage_section.get('depends_on')
        if dependency and dependency not in config:
            raise Exception(
                "Stage {} depends on {}, but dependency not found in config."
                .format(stage_name, dependency))

        # Later sources win: global < dependency < this stage
        dependency_parameters = config.get(dependency) if dependency else {}
        run_parameters = {**global_parameters, **dependency_parameters, **stage_section}

        start_time = datetime.now().astimezone()
        run_parameters['start_time'] = str(start_time)
        write_config_file(
            file_manager.create_file(prefix, stage_name, _IN_CONFIG_NAME, extension='.yml'),
            run_parameters)

        outputs = stage_runnable(**run_parameters)

        end_time = datetime.now().astimezone()
        outputs['end_time'] = str(end_time)
        outputs['elapsed_time'] = str(end_time - start_time)
        write_config_file(
            file_manager.create_file(prefix, stage_name, _OUT_CONFIG_NAME, extension='.yml'),
            outputs)

        # Replace the stage's section with its outputs so later `depends_on` sees them
        config[stage_name] = outputs
        post_stage(outputs)

    config['global'] = global_parameters
    write_config_file(
        file_manager.create_file(prefix, None, _OUT_CONFIG_NAME, extension='.yml'),
        config)

    return config
def deepblast(**kwargs) -> Dict[str, Any]:
    """Sequence-Sequence alignments with DeepBLAST

    DeepBLAST learned structural alignments from sequence

    https://github.com/flatironinstitute/deepblast

    https://www.biorxiv.org/content/10.1101/2020.11.03.365932v1

    Exactly one of these must be given:
    * pairings_file: CSV with `query` and `target` columns, both referring to
      original ids of this stage's sequences
    * transferred_annotations_file: each row's `original_id` is aligned against
      the ids in its `k_nn_*_identifier` columns, which refer to entries of
      `reference_fasta_file`

    Optional: batch_size (default 50), device (must be a CUDA device),
    model_file (fetched via get_model_file if absent).

    Writes one `.a2m` file per query; `alignment_files` maps query id -> path.

    Raises MissingParameterError / InvalidParameterError for a wrong choice of
    inputs, RuntimeError if no CUDA device is usable, and ValueError for
    unknown query or target ids.
    """
    # TODO: Fix that logic before merging
    if "transferred_annotations_file" not in kwargs and "pairings_file" not in kwargs:
        raise MissingParameterError(
            "You need to specify either 'transferred_annotations_file' or 'pairings_file' for DeepBLAST"
        )
    if "transferred_annotations_file" in kwargs and "pairings_file" in kwargs:
        raise InvalidParameterError(
            "You can't specify both 'transferred_annotations_file' and 'pairings_file' for DeepBLAST"
        )
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # This stays below 8GB, so it should be a good default
    batch_size = result_kwargs.setdefault("batch_size", 50)

    if "device" in result_kwargs:
        device = torch.device(result_kwargs["device"])
        if device.type != "cuda":
            raise RuntimeError(
                f"You can only run DeepBLAST on a CUDA-compatible GPU, not on {device.type}"
            )
    else:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "DeepBLAST requires a CUDA-compatible GPU, but none was found")
        device = torch.device("cuda")

    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    mapping = {
        str(remapped): original
        for remapped, original in mapping_file[["original_id"]].itertuples()
    }

    # Keyed by original id
    query_by_id = {
        mapping[entry.name]: str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    }

    # You can either provide a set of pairing or use the output of k-nn with a fasta file for the reference embeddings
    if "pairings_file" in result_kwargs:
        pairings_file = read_csv(result_kwargs["pairings_file"])
        pairings = list(pairings_file[["query", "target"]].itertuples(index=False))
        target_by_id = query_by_id
    else:
        transferred_annotations_file = read_csv(
            result_kwargs["transferred_annotations_file"])
        pairings = []
        for _, row in transferred_annotations_file.iterrows():
            query = row["original_id"]
            for target in row.filter(regex="k_nn_.*_identifier"):
                pairings.append((query, target))

        target_by_id = {}
        for entry in read_fasta(result_kwargs["reference_fasta_file"]):
            target_by_id[entry.name] = str(entry.seq[:])

    # itertools.groupby below only merges *consecutive* alignments of a query,
    # and every group (re)opens that query's file with "w". Make all pairings of
    # a query adjacent (queries in first-appearance order, targets in their
    # original order) so a query spread over several rows can't overwrite itself.
    pairings_by_query = {}
    for pairing in pairings:
        pairings_by_query.setdefault(pairing[0], []).append(pairing)
    pairings = [pairing for group in pairings_by_query.values() for pairing in group]

    result_kwargs["alignment_files"] = dict()
    if not pairings:
        # Nothing to align (e.g. an empty pairings file); skip loading the model
        return result_kwargs

    # Validate before creating any output files
    unknown_queries = set(pairings_by_query) - set(query_by_id.keys())
    if unknown_queries:
        raise ValueError(f"Unknown query sequences: {unknown_queries}")
    unknown_targets = {target for _, target in pairings} - set(target_by_id.keys())
    if unknown_targets:
        raise ValueError(f"Unknown target sequences: {unknown_targets}")

    # Create one output file per query
    for query in pairings_by_query:
        filename = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            f"{slugify(query, lowercase=False)}_alignments",
            extension=".a2m",
        )
        result_kwargs["alignment_files"][query] = filename

    # Load the pretrained model
    if "model_file" not in result_kwargs:
        model_file = get_model_file("deepblast", "model_file")
    else:
        model_file = result_kwargs["model_file"]

    # Assumes deepblast_align yields (query, target, query_aligned, target_aligned)
    # in the order of `pairings` -- NOTE(review): confirm
    alignments = deepblast_align(pairings, query_by_id, target_by_id, model_file, device, batch_size)

    for query, query_alignments in itertools.groupby(alignments, key=lambda alignment: alignment[0]):
        _, targets, queries_aligned, targets_aligned = list(zip(*query_alignments))
        padded_query, padded_targets = pairwise_alignments_to_msa(
            queries_aligned, targets_aligned)
        with open(result_kwargs["alignment_files"][query], "w") as fp:
            fp.write(f">{query}\n")
            fp.write(f"{padded_query}\n")
            for target, padded_target in zip(targets, padded_targets):
                fp.write(f">{target}\n")
                fp.write(f"{padded_target}\n")

    return result_kwargs
def mmseqs_search_protocol(**kwargs) -> Dict[str, Any]:
    """Search this stage's sequences against a target set with MMseqs2.

    Search target, by priority:
    `search_profiles_directory` > `search_sequences_directory` > `search_sequences_file`
    (a FASTA file, turned into an mmseqs database inside the stage directory).
    Query, by priority: `query_profiles_directory` > `query_sequences_directory`,
    falling back to a database built from `remapped_sequences_file`.

    Optional:
    * mmseqs_search_options: option name -> value (default: {})
    * convert_to_profiles: also convert the search result into query profiles (default: False)

    When the alignment_output search option is set, the alignments are written
    as a TSV with a header row to `alignment_results_file`.

    Raises OSError if the mmseqs binary is missing and MissingParameterError if
    no search target is given.
    """
    if not check_mmseqs():
        raise OSError(
            "mmseqs binary could not be found. Please make sure it's in your PATH. "
            "You can download mmseqs2 from: https://github.com/soedinglab/MMseqs2/releases/latest"
        )

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)
    prefix = result_kwargs.get("prefix")
    stage_name = result_kwargs.get("stage_name")

    result_kwargs.setdefault("convert_to_profiles", False)
    result_kwargs.setdefault("mmseqs_search_options", {})

    # Parse options before building any database: invalid options fail fast
    options_config = result_kwargs.get('mmseqs_search_options')
    search_options = MMseqsSearchOptions()
    for option_name in options_config:
        option = MMseqsSearchOptionsEnum.from_str(option_name)
        search_options.add_option(option, options_config[option_name])

    target_keys = ("search_sequences_file", "search_sequences_directory", "search_profiles_directory")
    if not any(key in kwargs for key in target_keys):
        raise MissingParameterError(
            "You need to specify either 'search_sequences_file' (in FASTA format), 'search_sequences_directory'"
            " (a mmseqs database created with `mmseqs createdb ...`) or 'search_profiles_directory' "
            "(a mmseqs profile database created) after an `mmseqs search` and am `mmseqs result2profile`)."
        )

    # Resolve the search target: profiles > sequence database > FASTA file
    search_sequences_path = kwargs.get("search_profiles_directory")
    if search_sequences_path is None:
        search_sequences_path = kwargs.get("search_sequences_directory")
    if search_sequences_path is None and "search_sequences_file" in kwargs:
        search_db_directory = file_manager.create_directory(
            prefix, stage_name, "search_sequences_directory")
        create_mmseqs_database(kwargs['search_sequences_file'], Path(search_db_directory))
        result_kwargs['search_sequences_directory'] = search_db_directory
        search_sequences_path = search_db_directory

    # Resolve the query: profiles > sequence database > this stage's remapped sequences
    query_sequences_path = kwargs.get("query_profiles_directory")
    if query_sequences_path is None:
        query_sequences_path = kwargs.get("query_sequences_directory")
    if query_sequences_path is None:
        query_db_directory = file_manager.create_directory(
            prefix, stage_name, "query_sequences_directory")
        create_mmseqs_database(kwargs['remapped_sequences_file'], Path(query_db_directory))
        result_kwargs['query_sequences_directory'] = query_db_directory
        query_sequences_path = query_db_directory

    results_directory = file_manager.create_directory(
        prefix, stage_name, "mmseqs_search_results_directory")
    mmseqs_search(Path(query_sequences_path), Path(search_sequences_path),
                  Path(results_directory), search_options)
    result_kwargs['mmseqs_search_results_directory'] = results_directory

    if search_options.has_option(MMseqsSearchOptionsEnum.alignment_output):
        alignment_file = file_manager.create_file(
            prefix, stage_name, "alignment_results_file", extension=".tsv")
        convert_result_to_alignment_file(Path(query_sequences_path), Path(search_sequences_path),
                                         Path(results_directory), Path(alignment_file))

        # The converted TSV has no header row; rewrite it from a temporary copy with one prepended
        header_columns = [
            "query", "target", "fident", "alnlen", "mismatch", "gapopen",
            "qstart", "qend", "tstart", "tend", "evalue", "bits",
            "pident", "nident", "qlen", "tlen", "qcov", "tcov", "qaln", "taln",
        ]
        with temporary_copy(alignment_file) as original, open(alignment_file, 'w') as out:
            out.write("\t".join(header_columns) + "\n")
            for raw_line in original:
                out.write(raw_line.decode('utf-8'))

        result_kwargs['alignment_results_file'] = alignment_file

    if result_kwargs["convert_to_profiles"]:
        profiles_directory = file_manager.create_directory(
            prefix, stage_name, "query_profiles_directory")
        convert_mmseqs_result_to_profile(Path(query_sequences_path), Path(search_sequences_path),
                                         Path(results_directory), Path(profiles_directory))
        result_kwargs['query_profiles_directory'] = profiles_directory

    return result_kwargs
def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    For every sequence in ``remapped_sequences_file``, the selected model is
    asked for per-residue probabilities, which are then written to a CSV file
    in the stage directory.

    required:
     * protocol: key into ``_PROTOCOLS`` selecting the model class
     * prefix: output prefix used by the file manager
     * stage_name: the stage name
     * remapped_sequences_file: FASTA file with the (remapped) sequences
     * mapping_file: mapping file generated by the pipeline; its rows are
       expected in the same order as the entries of ``remapped_sequences_file``

    optional (see extract stage for details):
     * model_directory
     * device
     * half_precision (not read by this function)
     * half_precision_model
     * temperature: temperature for softmax (defaults to 1)

    Returns a deep copy of ``kwargs`` with ``temperature`` filled in and
    ``residue_probabilities_file`` set to the path of the written CSV.

    Raises RuntimeError for an unknown protocol, or when the sequences file
    and the mapping file contain a different number of entries.
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)
    result_kwargs = deepcopy(kwargs)
    if result_kwargs["protocol"] not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, but allowed are: {', '.join(_PROTOCOLS)}"
        )
    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))
    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[
        result_kwargs["protocol"]
    ]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    # Forward kwargs (as the embed stage does) so that a configured file
    # manager is honoured instead of always using the default one.
    file_manager = get_file_manager(**kwargs)
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    # The mapping file contains the corresponding ids in the same order
    sequences = [
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    ]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    # zip() below would silently drop surplus entries on a mismatch; fail
    # loudly instead of writing a partial result.
    if len(sequences) != len(mapping_file.index):
        raise RuntimeError(
            f"{result_kwargs['remapped_sequences_file']} contains {len(sequences)} "
            f"sequences, but {result_kwargs['mapping_file']} contains "
            f"{len(mapping_file.index)} entries"
        )

    probabilities_all = {}
    with tqdm(total=int(mapping_file["sequence_length"].sum())) as progress_bar:
        for sequence_id, sequence in zip(mapping_file.index, sequences):
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )

            # Sanity check: apart from the "position" entry, the values of
            # each per-residue dict are softmax outputs and must sum to 1.
            for p in probabilities:
                assert math.isclose(
                    1,
                    sum(value for key, value in p.items() if key != "position"),
                    rel_tol=1e-6,
                ), "softmax values should add up to 1"

            probabilities_all[sequence_id] = probabilities

    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )
    residue_probabilities.to_csv(probabilities_file, index=False)
    result_kwargs["residue_probabilities_file"] = probabilities_file
    return result_kwargs