Example #1
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name

    Returns
    -------
    Dictionary with results of stage
    """
    check_required(
        kwargs,
        [
            "protocol", "prefix", "stage_name", "remapped_sequences_file",
            "mapping_file"
        ],
    )

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())))

    embedder_class = PROTOCOLS[kwargs["protocol"]]

    if embedder_class == UniRepEmbedder and kwargs.get("use_cpu") is not None:
        raise InvalidParameterError(
            "UniRep does not support configuring `use_cpu`")

    result_kwargs = deepcopy(kwargs)

    # Download necessary files if needed
    # noinspection PyProtectedMember
    for file in embedder_class._necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=embedder_class.name,
                                                 file=file)

    # noinspection PyProtectedMember
    for directory in embedder_class._necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(
                model=embedder_class.name, directory=directory)

    result_kwargs.setdefault("max_amino_acids",
                             DEFAULT_MAX_AMINO_ACIDS[kwargs["protocol"]])

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
Example #2
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name

    Returns
    -------
    Dictionary with results of stage
    """
    embedder_class, result_kwargs = prepare_kwargs(**kwargs)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
Example #3
def _process_fasta_file(**kwargs):
    """
    Will assign the MD5 hash as ID if no ID is provided for a sequence.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    sequences = read_fasta(kwargs['sequences_file'])
    sequences_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                   None,
                                                   'sequences_file',
                                                   extension='.fasta')
    write_fasta_file(sequences, sequences_file_path)

    result_kwargs['sequences_file'] = sequences_file_path

    # Remap using sequence position rather than md5 hash -- not encouraged!
    result_kwargs['simple_remapping'] = result_kwargs.get(
        'simple_remapping', False)

    mapping = reindex_sequences(sequences,
                                simple=result_kwargs['simple_remapping'])

    # Check if there's the same MD5 index twice. This most likely indicates 100% sequence identity.
    # Throw an error for MD5 hash clashes!
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                 None,
                                                 'mapping_file',
                                                 extension='.csv')
    remapped_sequence_file_path = file_manager.create_file(
        kwargs.get('prefix'),
        None,
        'remapped_sequences_file',
        extension='.fasta')

    write_fasta_file(sequences, remapped_sequence_file_path)
    mapping.to_csv(mapping_file_path)

    result_kwargs['mapping_file'] = mapping_file_path
    result_kwargs['remapped_sequences_file'] = remapped_sequence_file_path

    return result_kwargs
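
A sketch of driving this helper directly (it is normally called by the pipeline with the global parameters; the paths below are placeholders):

# Hypothetical direct call to the FASTA preprocessing helper
processed = _process_fasta_file(
    sequences_file="sequences.fasta",
    prefix="my_run",
)
# The returned dict points at the generated files:
# processed['sequences_file'], processed['mapping_file'],
# processed['remapped_sequences_file'], plus the 'simple_remapping' flag.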
Example #4
def run(**kwargs):
    """
    Run project protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_reduced_embeddings_file or projected_embeddings_file or reduced_embeddings_file: Where per-protein embeddings live
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which projection technique to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes

    Returns
    -------
    Dictionary with results of stage
    """
    check_required(kwargs, ["protocol", "prefix", "stage_name", "mapping_file"])

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())
            )
        )

    result_kwargs = deepcopy(kwargs)

    # We want to allow chaining protocols, e.g. first tucker, then umap,
    # so we need to allow projected embeddings as input
    embeddings_input_file = (
        kwargs.get("projected_reduced_embeddings_file")
        or kwargs.get("projected_embeddings_file")
        or kwargs.get("reduced_embeddings_file")
    )
    if not embeddings_input_file:
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file or "
            f"reduced_embeddings_file for {kwargs['protocol']}"
        )
    result_kwargs["reduced_embeddings_file"] = embeddings_input_file

    file_manager = get_file_manager(**kwargs)

    result_kwargs = PROTOCOLS[kwargs["protocol"]](file_manager, result_kwargs)

    return result_kwargs
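
A minimal call sketch for the project stage (keyword names taken from check_required and the embeddings-input check above; the protocol name and paths are assumptions):

# Hypothetical invocation; any one of the three *_embeddings_file inputs is accepted
result = run(
    protocol="tsne",                        # assumed name, must be a key of PROTOCOLS
    prefix="my_run",
    stage_name="projection",
    mapping_file="mapping_file.csv",
    reduced_embeddings_file="reduced_embeddings_file.h5",
)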
Example #5
def tsne(**kwargs):
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Get sequence mapping to use as information source
    mapping = read_csv(result_kwargs['mapping_file'], index_col=0)

    reduced_embeddings_file_path = result_kwargs['reduced_embeddings_file']

    reduced_embeddings = []

    with h5py.File(reduced_embeddings_file_path, 'r') as f:
        for remapped_id in mapping.index:
            reduced_embeddings.append(np.array(f[str(remapped_id)]))

    # Get parameters or set defaults
    result_kwargs['perplexity'] = kwargs.get('perplexity', 6)
    result_kwargs['n_jobs'] = kwargs.get('n_jobs', -1)
    result_kwargs['n_iter'] = kwargs.get('n_iter', 15000)

    result_kwargs['metric'] = kwargs.get('metric', 'cosine')
    result_kwargs['n_components'] = kwargs.get('n_components', 3)
    result_kwargs['random_state'] = kwargs.get('random_state', 420)
    result_kwargs['verbose'] = kwargs.get('verbose', 1)

    # Pass result_kwargs so the defaults set above are actually applied
    projected_embeddings = tsne_reduce(reduced_embeddings, **result_kwargs)

    for i in range(result_kwargs['n_components']):
        mapping[f'component_{i}'] = projected_embeddings[:, i]

    projected_embeddings_file_path = file_manager.create_file(
        kwargs.get('prefix'),
        result_kwargs.get('stage_name'),
        'projected_embeddings_file',
        extension='.csv')

    mapping.to_csv(projected_embeddings_file_path)
    result_kwargs['projected_embeddings_file'] = projected_embeddings_file_path

    return result_kwargs
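
For reference, a sketch of calling this protocol directly (file names are placeholders; anything not passed falls back to the defaults set above):

# Hypothetical kwargs for the tsne protocol; omitted options use the defaults
# shown above (perplexity=6, n_iter=15000, metric='cosine', n_components=3, ...)
result = tsne(
    prefix="my_run",
    stage_name="projection",
    mapping_file="mapping_file.csv",
    reduced_embeddings_file="reduced_embeddings_file.h5",
)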
Example #6
def run(**kwargs):
    """
    Run embedding protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        sequences_file: Where sequences live
        prefix: Output prefix for all generated files
        protocol: Which embedder to use
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        stage_name: The stage name

    Returns
    -------
    Dictionary with results of stage
    """
    embedder_class, result_kwargs = prepare_kwargs(**kwargs)

    # Download necessary files if needed
    # noinspection PyProtectedMember
    for file in embedder_class.necessary_files:
        if not result_kwargs.get(file):
            result_kwargs[file] = get_model_file(model=embedder_class.name,
                                                 file=file)

    # noinspection PyProtectedMember
    for directory in embedder_class.necessary_directories:
        if not result_kwargs.get(directory):
            result_kwargs[directory] = get_model_directories_from_zip(
                model=embedder_class.name, directory=directory)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   kwargs.get("half_precision", False))
Example #7
def plot_mutagenesis(result_kwargs):
    """BETA: visualize in-silico mutagenesis as a heatmap with plotly

    mandatory:
    * residue_probabilities_file
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "residue_probabilities_file",
    ]
    check_required(result_kwargs, required_kwargs)
    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    probabilities_all = pandas.read_csv(result_kwargs["residue_probabilities_file"])
    assert (
        list(probabilities_all.columns) == PROBABILITIES_COLUMNS
    ), f"probabilities file is expected to have the following columns: {PROBABILITIES_COLUMNS}"
    number_of_proteins = len(set(probabilities_all["id"]))

    for sequence_id, probabilities in tqdm(
        probabilities_all.groupby("id"), total=number_of_proteins
    ):
        fig = plot(probabilities)
        plotly.offline.plot(
            fig,
            filename=file_manager.create_file(
                result_kwargs.get("prefix"),
                result_kwargs.get("stage_name"),
                sequence_id,
                extension=".html",
            ),
        )

    return result_kwargs
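
A minimal call sketch (keys from required_kwargs above; the protocol value and path are placeholder assumptions):

# Hypothetical invocation; residue_probabilities_file is produced by the
# mutagenesis run() protocol shown at the end of this listing
result = plot_mutagenesis({
    "protocol": "plot_mutagenesis",         # assumed protocol name
    "prefix": "my_run",
    "stage_name": "mutagenesis_plots",
    "residue_probabilities_file": "residue_probabilities_file.csv",
})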
Example #8
def _process_fasta_file(**kwargs):
    """
    Will assign the MD5 hash as ID if no ID is provided for a sequence.
    """
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    sequences = read_fasta(kwargs['sequences_file'])

    # Sanity check the fasta file to avoid nonsense and/or crashes by the embedders
    letters = set(string.ascii_letters)
    for entry in sequences:
        illegal = sorted(set(entry.seq) - letters)
        if illegal:
            formatted = "'" + "', '".join(illegal) + "'"
            raise ValueError(
                f"The entry '{entry.name}' in {kwargs['sequences_file']} contains the characters {formatted}, "
                f"while only single letter code is allowed "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )
        # This is a warning due to the inconsistent handling between different embedders
        if not str(entry.seq).isupper():
            logger.warning(
                f"The entry '{entry.name}' in {kwargs['sequences_file']} contains lower case amino acids. "
                f"Lower case letters are uninterpretable by most language models, "
                f"and their embedding will be nonesensical. "
                f"Protein LMs available through bio_embeddings have been trained on upper case, "
                f"single letter code sequence representations only "
                f"(https://en.wikipedia.org/wiki/Amino_acid#Table_of_standard_amino_acid_abbreviations_and_properties)."
            )

    sequences_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                   None,
                                                   'sequences_file',
                                                   extension='.fasta')
    write_fasta_file(sequences, sequences_file_path)

    result_kwargs['sequences_file'] = sequences_file_path

    # Remap using sequence position rather than md5 hash -- not encouraged!
    result_kwargs['simple_remapping'] = result_kwargs.get(
        'simple_remapping', False)

    mapping = reindex_sequences(sequences,
                                simple=result_kwargs['simple_remapping'])

    # Check if there's the same MD5 index twice. This most likely indicates 100% sequence identity.
    # Throw an error for MD5 hash clashes!
    if mapping.index.has_duplicates:
        raise MD5ClashException(
            "There is at least one MD5 hash clash.\n"
            "This most likely indicates there are multiple identical sequences in your FASTA file.\n"
            "MD5 hashes are used to remap sequence identifiers from the input FASTA.\n"
            "This error exists to prevent wasting resources (computing the same embedding twice).\n"
            "There's a (very) low probability of this indicating a real MD5 clash.\n\n"
            "If you are sure there are no identical sequences in your set, please open an issue at "
            "https://github.com/sacdallago/bio_embeddings/issues . "
            "Otherwise, use cd-hit to reduce your input FASTA to exclude identical sequences!"
        )

    mapping_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                 None,
                                                 'mapping_file',
                                                 extension='.csv')
    remapped_sequence_file_path = file_manager.create_file(
        kwargs.get('prefix'),
        None,
        'remapped_sequences_file',
        extension='.fasta')

    write_fasta_file(sequences, remapped_sequence_file_path)
    mapping.to_csv(mapping_file_path)

    result_kwargs['mapping_file'] = mapping_file_path
    result_kwargs['remapped_sequences_file'] = remapped_sequence_file_path

    return result_kwargs
Example #9
def execute_pipeline_from_config(config: Dict,
                                 post_stage: Callable[[Dict],
                                                      None] = _null_function,
                                 **kwargs) -> Dict:
    original_config = deepcopy(config)

    check_required(config, ["global"])

    # !! pop = remove from config!
    global_parameters = config.pop('global')

    check_required(global_parameters, ["prefix", "sequences_file"])

    file_manager = get_file_manager(**global_parameters)

    # Make sure prefix exists
    prefix = global_parameters['prefix']

    # If prefix already exists
    if file_manager.exists(prefix):
        if not kwargs.get('overwrite'):
            raise FileExistsError(
                "The prefix already exists & no overwrite option has been set.\n"
                "Either set --overwrite, or move data from the prefix.\n"
                "Prefix: {}".format(prefix))
    else:
        # create the prefix
        file_manager.create_prefix(prefix)

    # Copy original config to prefix
    global_in = file_manager.create_file(prefix,
                                         None,
                                         _IN_CONFIG_NAME,
                                         extension='.yml')
    write_config_file(global_in, original_config)

    # This downloads sequences_file if required
    download_files_for_stage(global_parameters, file_manager, prefix)

    global_parameters = _process_fasta_file(**global_parameters)

    for stage_name in config:
        stage_parameters = config[stage_name]
        original_stage_parameters = dict(**stage_parameters)

        check_required(stage_parameters, ["protocol", "type"])

        stage_type = stage_parameters['type']
        stage_runnable = _STAGES.get(stage_type)

        if not stage_runnable:
            raise Exception(
                "No type defined, or invalid stage type defined: {}".format(
                    stage_type))

        # Prepare to run stage
        stage_parameters['stage_name'] = stage_name
        file_manager.create_stage(prefix, stage_name)

        stage_parameters = download_files_for_stage(stage_parameters,
                                                    file_manager, prefix,
                                                    stage_name)

        stage_dependency = stage_parameters.get('depends_on')

        if stage_dependency:
            if stage_dependency not in config:
                raise Exception(
                    "Stage {} depends on {}, but dependency not found in config."
                    .format(stage_name, stage_dependency))

            stage_dependency_parameters = config.get(stage_dependency)
            stage_parameters = {
                **global_parameters,
                **stage_dependency_parameters,
                **stage_parameters
            }
        else:
            stage_parameters = {**global_parameters, **stage_parameters}

        # Register start time
        start_time = datetime.now().astimezone()
        stage_parameters['start_time'] = str(start_time)

        stage_in = file_manager.create_file(prefix,
                                            stage_name,
                                            _IN_CONFIG_NAME,
                                            extension='.yml')
        write_config_file(stage_in, stage_parameters)

        try:
            stage_output_parameters = stage_runnable(**stage_parameters)
        except Exception as e:
            # Tell the user which stage failed and show an url to report an error on github
            try:
                version = importlib_metadata.version("bio_embeddings")
            except PackageNotFoundError:
                version = "unknown"

            # Make a github flavored markdown table; the header is in the template
            parameter_table = "\n".join(
                f"{key}|{value}"
                for key, value in original_stage_parameters.items())
            params = {
                # https://stackoverflow.com/a/35498685/3549270
                "title":
                f"Protocol {original_stage_parameters['protocol']}: {type(e).__name__}: {e}",
                "body":
                _ERROR_REPORTING_TEMPLATE.format(
                    version,
                    torch.cuda.is_available(),
                    parameter_table,
                    traceback.format_exc(10),
                ),
            }
            print(traceback.format_exc(), file=sys.stderr)
            print(
                f"Consider reporting this error at this url: {_ISSUE_URL}?{urllib.parse.urlencode(params)}\n\n"
                f"Stage {stage_name} failed.",
                file=sys.stderr,
            )

            sys.exit(1)

        # Register end time
        end_time = datetime.now().astimezone()
        stage_output_parameters['end_time'] = str(end_time)

        # Register elapsed time
        stage_output_parameters['elapsed_time'] = str(end_time - start_time)

        stage_out = file_manager.create_file(prefix,
                                             stage_name,
                                             _OUT_CONFIG_NAME,
                                             extension='.yml')
        write_config_file(stage_out, stage_output_parameters)

        # Store in global_out config for later retrieval (e.g. depends_on)
        config[stage_name] = stage_output_parameters

        # Execute post-stage function, if provided
        post_stage(stage_output_parameters)

    config['global'] = global_parameters

    try:
        config['global']['version'] = importlib_metadata.version(
            "bio_embeddings")
    except PackageNotFoundError:
        pass  # :(

    global_out = file_manager.create_file(prefix,
                                          None,
                                          _OUT_CONFIG_NAME,
                                          extension='.yml')
    write_config_file(global_out, config)

    return config
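
A minimal config sketch for this entry point (only the keys enforced by check_required are shown; the stage name, stage type and protocol values are assumptions):

# Hypothetical pipeline config: 'global' needs prefix and sequences_file,
# every other top-level key is treated as a stage and needs 'type' and 'protocol'
config = {
    "global": {
        "prefix": "my_run",
        "sequences_file": "sequences.fasta",
    },
    "my_embeddings": {
        "type": "embed",                    # assumed, must be a key of _STAGES
        "protocol": "seqvec",               # assumed protocol name
    },
}
result = execute_pipeline_from_config(config, overwrite=True)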
Example #10
def plotly(**kwargs):
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >3, will render 3D scatter plot.")

    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    projected_embeddings_file = read_csv(result_kwargs['projected_embeddings_file'], index_col=0)

    if result_kwargs.get('annotation_file'):

        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')

        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        if result_kwargs['merge_via_index']:
            if result_kwargs['display_unknown']:
                merged_annotation_file = annotation_file.join(projected_embeddings_file, how="outer")
                merged_annotation_file['label'].fillna('UNKNOWN', inplace=True)
            else:
                merged_annotation_file = annotation_file.join(projected_embeddings_file)
        else:
            if result_kwargs['display_unknown']:
                merged_annotation_file = annotation_file.join(projected_embeddings_file.set_index('original_id'), how="outer")
                merged_annotation_file['label'].fillna('UNKNOWN', inplace=True)
            else:
                merged_annotation_file = annotation_file.join(projected_embeddings_file.set_index('original_id'))

        merged_annotation_file_path = file_manager.create_file(kwargs.get('prefix'),
                                                               result_kwargs.get('stage_name'),
                                                               'merged_annotation_file',
                                                               extension='.csv')

        merged_annotation_file.to_csv(merged_annotation_file_path)
        result_kwargs['merged_annotation_file'] = merged_annotation_file_path
    else:
        # Without an annotation file, plot the projected embeddings on their own
        merged_annotation_file = projected_embeddings_file
        merged_annotation_file['label'] = 'UNKNOWN'

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')

    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs
Example #11
def plotly(result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    file_manager = get_file_manager(**result_kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = result_kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >3, will render 3D scatter plot.")
    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    suffix = Path(result_kwargs["projected_reduced_embeddings_file"]).suffix
    if suffix == ".csv":
        # Support the legacy csv format
        merged_annotation_file = read_csv(
            result_kwargs["projected_reduced_embeddings_file"], index_col=0
        )
    elif suffix == ".h5":
        # convert h5 to dataframe with ids and one column per dimension
        rows = []
        with h5py.File(result_kwargs["projected_reduced_embeddings_file"], "r") as file:
            for sequence_id, embedding in file.items():
                if embedding.shape != (3,):
                    raise RuntimeError(
                        f"Expected embeddings in projected_reduced_embeddings_file "
                        f"to be of shape (3,), not {embedding.shape}"
                    )
                row = (
                    sequence_id,
                    embedding.attrs["original_id"],
                    embedding[0],
                    embedding[1],
                    embedding[2],
                )
                rows.append(row)
        columns = [
            "sequence_id",
            "original_id",
            "component_0",
            "component_1",
            "component_2",
        ]
        merged_annotation_file = DataFrame.from_records(
            rows, index="sequence_id", columns=columns
        )
    else:
        raise InvalidParameterError(
            f"Expected .csv or .h5 as suffix for projected_reduced_embeddings_file, got {suffix}"
        )

    if result_kwargs.get('annotation_file'):

        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')

        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        if result_kwargs['merge_via_index']:
            if result_kwargs['display_unknown']:
                merged_annotation_file = annotation_file.join(merged_annotation_file, how="outer")
                merged_annotation_file['label'].fillna('UNKNOWN', inplace=True)
            else:
                merged_annotation_file = annotation_file.join(merged_annotation_file)
        else:
            if result_kwargs['display_unknown']:
                merged_annotation_file = annotation_file.join(merged_annotation_file.set_index('original_id'),
                                                              how="outer")
                merged_annotation_file['label'].fillna('UNKNOWN', inplace=True)
            else:
                merged_annotation_file = annotation_file.join(merged_annotation_file.set_index('original_id'))
    else:
        merged_annotation_file['label'] = 'UNKNOWN'

    merged_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                           result_kwargs.get('stage_name'),
                                                           'merged_annotation_file',
                                                           extension='.csv')

    merged_annotation_file.to_csv(merged_annotation_file_path)
    result_kwargs['merged_annotation_file'] = merged_annotation_file_path

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')

    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs
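
A minimal call sketch (keys taken from the function body above; paths are placeholders):

# Hypothetical invocation; projected_reduced_embeddings_file may be .csv or .h5,
# annotation_file is optional and needs an 'identifier' column
result = plotly({
    "prefix": "my_run",
    "stage_name": "visualization",
    "projected_reduced_embeddings_file": "projected_reduced_embeddings_file.h5",
    "annotation_file": "annotations.csv",
    "n_components": 3,
})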
Example #12
def execute_pipeline_from_config(config: Dict,
                                 post_stage: Callable[[Dict],
                                                      None] = _null_function,
                                 **kwargs) -> Dict:
    original_config = deepcopy(config)

    check_required(config, ["global"])

    # !! pop = remove from config!
    global_parameters = config.pop('global')

    check_required(global_parameters, ["prefix", "sequences_file"])

    file_manager = get_file_manager(**global_parameters)

    # Make sure prefix exists
    prefix = global_parameters['prefix']

    # If prefix already exists
    if file_manager.exists(prefix):
        if not kwargs.get('overwrite'):
            raise FileExistsError(
                "The prefix already exists & no overwrite option has been set.\n"
                "Either set --overwrite, or move data from the prefix.\n"
                "Prefix: {}".format(prefix))
    else:
        # create the prefix
        file_manager.create_prefix(prefix)

    try:
        Path(prefix).joinpath("bio_embeddings_version.txt").write_text(
            importlib_metadata.version("bio_embeddings"))
    except PackageNotFoundError:
        pass  # :(

    # Copy original config to prefix
    global_in = file_manager.create_file(prefix,
                                         None,
                                         _IN_CONFIG_NAME,
                                         extension='.yml')
    write_config_file(global_in, original_config)

    global_parameters = _process_fasta_file(**global_parameters)

    for stage_name in config:
        stage_parameters = config[stage_name]

        check_required(stage_parameters, ["protocol", "type"])

        stage_type = stage_parameters['type']
        stage_runnable = _STAGES.get(stage_type)

        if not stage_runnable:
            raise Exception(
                "No type defined, or invalid stage type defined: {}".format(
                    stage_type))

        # Prepare to run stage
        stage_parameters['stage_name'] = stage_name
        file_manager.create_stage(prefix, stage_name)

        stage_dependency = stage_parameters.get('depends_on')

        if stage_dependency:
            if stage_dependency not in config:
                raise Exception(
                    "Stage {} depends on {}, but dependency not found in config."
                    .format(stage_name, stage_dependency))

            stage_dependency_parameters = config.get(stage_dependency)
            stage_parameters = {
                **global_parameters,
                **stage_dependency_parameters,
                **stage_parameters
            }
        else:
            stage_parameters = {**global_parameters, **stage_parameters}

        # Register start time
        start_time = datetime.now().astimezone()
        stage_parameters['start_time'] = str(start_time)

        stage_in = file_manager.create_file(prefix,
                                            stage_name,
                                            _IN_CONFIG_NAME,
                                            extension='.yml')
        write_config_file(stage_in, stage_parameters)

        stage_output_parameters = stage_runnable(**stage_parameters)

        # Register end time
        end_time = datetime.now().astimezone()
        stage_output_parameters['end_time'] = str(end_time)

        # Register elapsed time
        stage_output_parameters['elapsed_time'] = str(end_time - start_time)

        stage_out = file_manager.create_file(prefix,
                                             stage_name,
                                             _OUT_CONFIG_NAME,
                                             extension='.yml')
        write_config_file(stage_out, stage_output_parameters)

        # Store in global_out config for later retrieval (e.g. depends_on)
        config[stage_name] = stage_output_parameters

        # Execute post-stage function, if provided
        post_stage(stage_output_parameters)

    config['global'] = global_parameters
    global_out = file_manager.create_file(prefix,
                                          None,
                                          _OUT_CONFIG_NAME,
                                          extension='.yml')
    write_config_file(global_out, config)

    return config
Example #13
def deepblast(**kwargs) -> Dict[str, Any]:
    """Sequence-Sequence alignments with DeepBLAST

    DeepBLAST learned structural alignments from sequence

    https://github.com/flatironinstitute/deepblast

    https://www.biorxiv.org/content/10.1101/2020.11.03.365932v1
    """
    # TODO: Fix that logic before merging
    if "transferred_annotations_file" not in kwargs and "pairings_file" not in kwargs:
        raise MissingParameterError(
            "You need to specify either 'transferred_annotations_file' or 'pairings_file' for DeepBLAST"
        )
    if "transferred_annotations_file" in kwargs and "pairings_file" in kwargs:
        raise InvalidParameterError(
            "You can't specify both 'transferred_annotations_file' and 'pairings_file' for DeepBLAST"
        )
    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # This stays below 8GB, so it should be a good default
    batch_size = result_kwargs.setdefault("batch_size", 50)

    if "device" in result_kwargs:
        device = torch.device(result_kwargs["device"])
        if device.type != "cuda":
            raise RuntimeError(
                f"You can only run DeepBLAST on a CUDA-compatible GPU, not on {device.type}"
            )
    else:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "DeepBLAST requires a CUDA-compatible GPU, but none was found")
        device = torch.device("cuda")

    mapping_file = read_mapping_file(result_kwargs["mapping_file"])
    mapping = {
        str(remapped): original
        for remapped, original in mapping_file[["original_id"]].itertuples()
    }

    query_by_id = {
        mapping[entry.name]: str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"],
                                 "fasta")
    }

    # You can either provide a set of pairings or use the output of k-nn with a fasta file for the reference sequences
    if "pairings_file" in result_kwargs:
        pairings_file = read_csv(result_kwargs["pairings_file"])
        pairings = list(pairings_file[["query",
                                       "target"]].itertuples(index=False))
        target_by_id = query_by_id
    else:
        transferred_annotations_file = read_csv(
            result_kwargs["transferred_annotations_file"])
        pairings = []
        for _, row in transferred_annotations_file.iterrows():
            query = row["original_id"]
            for target in row.filter(regex="k_nn_.*_identifier"):
                pairings.append((query, target))

        target_by_id = {}
        for entry in read_fasta(result_kwargs["reference_fasta_file"]):
            target_by_id[entry.name] = str(entry.seq[:])

    # Create one output file per query
    result_kwargs["alignment_files"] = dict()
    for query in set(i for i, _ in pairings):
        filename = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            f"{slugify(query, lowercase=False)}_alignments",
            extension=".a2m",
        )
        result_kwargs["alignment_files"][query] = filename

    unknown_queries = set(list(zip(*pairings))[0]) - set(query_by_id.keys())
    if unknown_queries:
        raise ValueError(f"Unknown query sequences: {unknown_queries}")

    unknown_targets = set(list(zip(*pairings))[1]) - set(target_by_id.keys())
    if unknown_targets:
        raise ValueError(f"Unknown target sequences: {unknown_targets}")

    # Load the pretrained model
    if "model_file" not in result_kwargs:
        model_file = get_model_file("deepblast", "model_file")
    else:
        model_file = result_kwargs["model_file"]

    alignments = deepblast_align(pairings, query_by_id, target_by_id,
                                 model_file, device, batch_size)

    for query, query_alignments in itertools.groupby(alignments, key=lambda i: i[0]):
        _, targets, queries_aligned, targets_aligned = list(zip(*query_alignments))
        padded_query, padded_targets = pairwise_alignments_to_msa(
            queries_aligned, targets_aligned)
        with open(result_kwargs["alignment_files"][query], "w") as fp:
            fp.write(f">{query}\n")
            fp.write(f"{padded_query}\n")
            for target, padded_target in zip(targets, padded_targets):
                fp.write(f">{target}\n")
                fp.write(f"{padded_target}\n")

    return result_kwargs
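
A minimal call sketch for the pairings_file variant (keys from the checks above; paths are placeholders, and a CUDA-capable GPU is required):

# Hypothetical invocation; pairings_file needs 'query' and 'target' columns,
# and every query/target id must appear among the remapped sequences
result = deepblast(
    prefix="my_run",
    stage_name="alignments",
    mapping_file="mapping_file.csv",
    remapped_sequences_file="remapped_sequences_file.fasta",
    pairings_file="pairings.csv",
)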
Example #14
def mmseqs_search_protocol(**kwargs) -> Dict[str, Any]:
    # Check that mmseqs2 is installed
    if not check_mmseqs():
        raise OSError(
            "mmseqs binary could not be found. Please make sure it's in your PATH. "
            "You can download mmseqs2 from: https://github.com/soedinglab/MMseqs2/releases/latest"
        )

    result_kwargs = deepcopy(kwargs)
    file_manager = get_file_manager(**kwargs)

    # Set defaults
    result_kwargs.setdefault("convert_to_profiles", False)
    result_kwargs.setdefault("mmseqs_search_options", {})

    # Build options (if this fails: no point in creating dbs!)
    mmseqs_search_options = result_kwargs.get('mmseqs_search_options')
    search_options = MMseqsSearchOptions()

    for search_option in mmseqs_search_options:
        option_enum = MMseqsSearchOptionsEnum.from_str(search_option)
        search_options.add_option(option_enum,
                                  mmseqs_search_options[search_option])

    # Check that either search_sequences_file,
    # or search_sequence_directory (a mmseqs db),
    # or search_profiles_directory (a mmseqs db of a profile) is in kwargs.
    # Priority: search_profiles_directory > search_sequence_directory > search_sequences_file
    if not ("search_sequences_file" in kwargs or "search_sequences_directory"
            in kwargs or "search_profiles_directory" in kwargs):
        raise MissingParameterError(
            "You need to specify either 'search_sequences_file' (in FASTA format), 'search_sequences_directory'"
            " (a mmseqs database created with `mmseqs createdb ...`) or 'search_profiles_directory' "
            "(a mmseqs profile database created) after an `mmseqs search` and am `mmseqs result2profile`)."
        )

    search_sequences_path = None

    if "search_profiles_directory" in kwargs:
        search_sequences_path = kwargs['search_profiles_directory']

    if search_sequences_path is None and "search_sequences_directory" in kwargs:
        search_sequences_path = kwargs['search_sequences_directory']

    if search_sequences_path is None and "search_sequences_file" in kwargs:
        search_sequences_directory = file_manager.create_directory(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "search_sequences_directory",
        )

        create_mmseqs_database(kwargs['search_sequences_file'],
                               Path(search_sequences_directory))
        result_kwargs[
            'search_sequences_directory'] = search_sequences_directory
        search_sequences_path = search_sequences_directory

    query_sequences_path = None

    if "query_profiles_directory" in kwargs:
        query_sequences_path = kwargs['query_profiles_directory']

    if query_sequences_path is None and "query_sequences_directory" in kwargs:
        query_sequences_path = kwargs['query_sequences_directory']

    if query_sequences_path is None:
        query_sequences_directory = file_manager.create_directory(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "query_sequences_directory",
        )

        create_mmseqs_database(kwargs['remapped_sequences_file'],
                               Path(query_sequences_directory))
        result_kwargs['query_sequences_directory'] = query_sequences_directory
        query_sequences_path = query_sequences_directory

    mmseqs_search_results_directory = file_manager.create_directory(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "mmseqs_search_results_directory",
    )

    mmseqs_search(Path(query_sequences_path), Path(search_sequences_path),
                  Path(mmseqs_search_results_directory), search_options)

    result_kwargs[
        'mmseqs_search_results_directory'] = mmseqs_search_results_directory

    if search_options.has_option(MMseqsSearchOptionsEnum.alignment_output):
        mmseqs_search_results_file = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "alignment_results_file",
            extension=".tsv")

        convert_result_to_alignment_file(Path(query_sequences_path),
                                         Path(search_sequences_path),
                                         Path(mmseqs_search_results_directory),
                                         Path(mmseqs_search_results_file))

        # Prepend a header to the TSV -- an awkward operation that requires rewriting every line of the file...
        header = "query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits," +\
                 "pident,nident,qlen,tlen,qcov,tcov,qaln,taln"

        with temporary_copy(mmseqs_search_results_file) as original,\
                open(mmseqs_search_results_file, 'w') as out:
            out.write(header.replace(",", "\t") + "\n")
            for line in original:
                out.write(line.decode('utf-8'))

        result_kwargs['alignment_results_file'] = mmseqs_search_results_file

    if result_kwargs["convert_to_profiles"]:
        query_profiles_directory = file_manager.create_directory(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "query_profiles_directory",
        )

        convert_mmseqs_result_to_profile(Path(query_sequences_path),
                                         Path(search_sequences_path),
                                         Path(mmseqs_search_results_directory),
                                         Path(query_profiles_directory))

        result_kwargs['query_profiles_directory'] = query_profiles_directory

    return result_kwargs
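
A minimal call sketch (keys from the parameter checks above; paths are placeholders):

# Hypothetical invocation; requires the mmseqs binary on PATH
result = mmseqs_search_protocol(
    prefix="my_run",
    stage_name="mmseqs_search",
    remapped_sequences_file="remapped_sequences_file.fasta",
    search_sequences_file="reference_sequences.fasta",
    convert_to_profiles=False,
)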
Example #15
def run(**kwargs):
    """BETA: in-silico mutagenesis using BertForMaskedLM

    optional (see extract stage for details):
     * model_directory
     * device
     * half_precision
     * half_precision_model
     * temperature: temperature for softmax
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)
    result_kwargs = deepcopy(kwargs)
    if result_kwargs["protocol"] not in _PROTOCOLS:
        raise RuntimeError(
            f"Passed protocol {result_kwargs['protocol']}, but allowed are: {', '.join(_PROTOCOLS)}"
        )
    temperature = result_kwargs.setdefault("temperature", 1)
    device = get_device(result_kwargs.get("device"))
    model_class: Type[ProtTransBertBFDMutagenesis] = _PROTOCOLS[
        result_kwargs["protocol"]
    ]
    model = model_class(
        device,
        result_kwargs.get("model_directory"),
        result_kwargs.get("half_precision_model"),
    )

    file_manager = get_file_manager()
    file_manager.create_stage(result_kwargs["prefix"], result_kwargs["stage_name"])

    # The mapping file contains the corresponding ids in the same order
    sequences = [
        str(entry.seq)
        for entry in SeqIO.parse(result_kwargs["remapped_sequences_file"], "fasta")
    ]
    mapping_file = read_mapping_file(result_kwargs["mapping_file"])

    probabilities_all = dict()
    with tqdm(total=int(mapping_file["sequence_length"].sum())) as progress_bar:
        for sequence_id, original_id, sequence in zip(
            mapping_file.index, mapping_file["original_id"], sequences
        ):
            with torch.no_grad():
                probabilities = model.get_sequence_probabilities(
                    sequence, temperature, progress_bar=progress_bar
                )

            for p in probabilities:
                assert math.isclose(
                    1, (sum(p.values()) - p["position"]), rel_tol=1e-6
                ), "softmax values should add up to 1"

            probabilities_all[sequence_id] = probabilities
    residue_probabilities = probabilities_as_dataframe(
        mapping_file, probabilities_all, sequences
    )

    probabilities_file = file_manager.create_file(
        result_kwargs.get("prefix"),
        result_kwargs.get("stage_name"),
        "residue_probabilities_file",
        extension=".csv",
    )
    residue_probabilities.to_csv(probabilities_file, index=False)
    result_kwargs["residue_probabilities_file"] = probabilities_file
    return result_kwargs
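
A minimal call sketch (keys from required_kwargs above; the protocol name and paths are assumptions):

# Hypothetical invocation; the returned dict points at residue_probabilities_file,
# which plot_mutagenesis (shown earlier) can consume
result = run(
    protocol="protbert_bfd_mutagenesis",    # assumed key of _PROTOCOLS
    prefix="my_run",
    stage_name="mutagenesis",
    remapped_sequences_file="remapped_sequences_file.fasta",
    mapping_file="mapping_file.csv",
    temperature=1,
)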