示例#1
0
def _check_transform_embeddings_function(embedder: EmbedderInterface, result_kwargs: Dict[str, Any]):
    """
    Validate the optional `embeddings_transformer_function` option.

    The option holds the source of a Python expression (typically a lambda) as a string.
    The expression is evaluated with `np` bound to numpy, called once on a template embedding,
    and its return value is checked for numpy compatibility, so that a broken function is
    reported up front.

    :param embedder: The embedder used to produce the template embedding the function is tried on
    :param result_kwargs: The stage options; `embeddings_transformer_function` is set to None
        in-place if it is missing

    :raises InvalidParameterError: if the expression cannot be evaluated, does not evaluate to a
        callable, fails on the template embedding or returns something numpy cannot handle
    """
    result_kwargs.setdefault("embeddings_transformer_function", None)

    function_source = result_kwargs["embeddings_transformer_function"]
    if function_source is None:
        # Nothing configured, nothing to check
        return

    not_callable_message = (f"`embeddings_transformer_function` must be callable! \n"
                            f"Instead is {function_source}\n"
                            f"Most likely you want a lambda function.")

    # SECURITY NOTE: eval executes arbitrary code taken from the configuration.
    # Only run configurations that come from a trusted source.
    try:
        transform_function = eval(function_source, {}, {"np": numpy})
    except Exception as error:
        # Covers non-string values (TypeError) as well as malformed expressions
        # (SyntaxError, NameError, ...), which should all be reported as a parameter problem
        raise InvalidParameterError(not_callable_message) from error

    if not callable(transform_function):
        raise InvalidParameterError(not_callable_message)

    template_embedding = embedder.embed("SEQVENCE")

    # Check that it works in principle
    try:
        transformed_template_embedding = transform_function(template_embedding)
    except Exception as error:
        # `Exception` rather than a bare except, so Ctrl-C and SystemExit still propagate
        raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable! \n"
                                    f"Instead is {function_source}\n"
                                    f"This function excepts when processing an embedding.") from error

    # Check that return can be cast to np.array
    try:
        numpy.array(transformed_template_embedding)
    except Exception as error:
        raise InvalidParameterError(f"`embeddings_transformer_function` must be valid callable "
                                    f"returning numpy array compatible object! \n"
                                    f"Instead is {function_source}\n"
                                    f"This function excepts when processing an embedding.") from error
示例#2
0
def run(**kwargs):
    """
    Run the embed stage: build the embedder for the selected protocol and embed the sequences.

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        protocol: Which embedder to use, must be a key of PROTOCOLS
        prefix: Output prefix for all generated files
        stage_name: The stage name
        remapped_sequences_file: Required by this stage (presumably the sequences after remapping)
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        * use_cpu: Must not be set when the protocol resolves to UniRepEmbedder
        * max_amino_acids: Defaults to the protocol's entry in DEFAULT_MAX_AMINO_ACIDS
        * half_precision: Forwarded to embed_and_write_batched, defaults to False

    Any file or directory the embedder class lists as necessary and that has no truthy
    value in the options is fetched via get_model_file / get_model_directories_from_zip.

    Returns
    -------
    The return value of embed_and_write_batched (the results of the stage)
    """
    required_options = [
        "protocol", "prefix", "stage_name", "remapped_sequences_file",
        "mapping_file"
    ]
    check_required(kwargs, required_options)

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        valid_protocols = ", ".join(PROTOCOLS.keys())
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                protocol, valid_protocols))

    embedder_class = PROTOCOLS[protocol]

    if embedder_class == UniRepEmbedder:
        if kwargs.get("use_cpu") is not None:
            raise InvalidParameterError(
                "UniRep does not support configuring `use_cpu`")

    # Work on a copy so the caller's options stay untouched
    result_kwargs = deepcopy(kwargs)

    # Fetch every model file the user did not provide
    # noinspection PyProtectedMember
    for file_option in embedder_class._necessary_files:
        if result_kwargs.get(file_option):
            continue
        result_kwargs[file_option] = get_model_file(
            model=embedder_class.name, file=file_option)

    # Same for model directories
    # noinspection PyProtectedMember
    for directory_option in embedder_class._necessary_directories:
        if result_kwargs.get(directory_option):
            continue
        result_kwargs[directory_option] = get_model_directories_from_zip(
            model=embedder_class.name, directory=directory_option)

    default_max_amino_acids = DEFAULT_MAX_AMINO_ACIDS[protocol]
    result_kwargs.setdefault("max_amino_acids", default_max_amino_acids)

    file_manager = get_file_manager(**kwargs)
    embedder: EmbedderInterface = embedder_class(**result_kwargs)
    _check_transform_embeddings_function(embedder, result_kwargs)

    half_precision = kwargs.get("half_precision", False)
    return embed_and_write_batched(embedder, file_manager, result_kwargs,
                                   half_precision)
示例#3
0
def run(**kwargs):
    """
    Run visualize protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which plot to generate, must be a key of PROTOCOLS

    For plotly:
        projected_reduced_embeddings_file: The projected (dimensionality reduced) embeddings, normally coming
            from the project stage. The legacy name projected_embeddings_file is accepted as a fallback
        * annotation_file: csv file with annotations
        * display_unknown: Whether to display proteins for which there is no annotation in the annotation file
            (only relevant if annotation file is provided)
        * merge_via_index: Set to True if in annotation_file identifiers correspond to sequence MD5 hashes
        * n_components: 2D vs 3D plot

    For plot_mutagenesis:
        residue_probabilities_file: The csv with the probabilities, normally coming from the mutagenesis stage

    Returns
    -------
    Dictionary with results of stage

    Raises
    ------
    InvalidParameterError
        If the protocol is unknown, or if plotly is selected without an input file
    """
    check_required(kwargs, ['protocol', 'prefix', 'stage_name'])

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: " +
            "{}. Valid protocols are: {}".format(kwargs["protocol"], ", ".join(
                PROTOCOLS.keys())))

    # Work on a copy so the caller's options stay untouched
    result_kwargs = deepcopy(kwargs)

    if kwargs["protocol"] == "plotly":
        # Support legacy projected_embeddings_file
        projected_reduced_embeddings_file = (
            kwargs.get("projected_reduced_embeddings_file")
            or kwargs.get("projected_embeddings_file"))
        if not projected_reduced_embeddings_file:
            # Only the two option names read above are accepted here
            raise InvalidParameterError(
                f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file "
                f"for {kwargs['protocol']}")
        result_kwargs[
            "projected_reduced_embeddings_file"] = projected_reduced_embeddings_file

    return PROTOCOLS[kwargs["protocol"]](result_kwargs)
示例#4
0
def run(**kwargs):
    """
    Run the project stage with the selected projection technique.

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        protocol: Which projection technique to use, must be a key of PROTOCOLS
        prefix: Output prefix for all generated files
        stage_name: The stage name
        mapping_file: the mapping file generated by the pipeline when remapping indexes
        projected_reduced_embeddings_file or projected_embeddings_file or reduced_embeddings_file:
            Where per-protein embeddings live. The first of these (in this order) with a truthy
            value is passed on as `reduced_embeddings_file`

    Returns
    -------
    Dictionary with results of stage, as returned by the selected protocol
    """
    check_required(kwargs, ["protocol", "prefix", "stage_name", "mapping_file"])

    protocol = kwargs["protocol"]
    if protocol not in PROTOCOLS:
        valid_protocols = ", ".join(PROTOCOLS.keys())
        raise InvalidParameterError(
            "Invalid protocol selection: "
            + "{}. Valid protocols are: {}".format(protocol, valid_protocols)
        )

    result_kwargs = deepcopy(kwargs)

    # Protocols can be chained (e.g. tucker followed by umap), so output of an
    # earlier projection is accepted as input too, in order of preference
    accepted_inputs = (
        "projected_reduced_embeddings_file",
        "projected_embeddings_file",
        "reduced_embeddings_file",
    )
    embeddings_input_file = next(
        (kwargs[option] for option in accepted_inputs if kwargs.get(option)),
        None,
    )
    if not embeddings_input_file:
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file or "
            f"reduced_embeddings_file for {kwargs['protocol']}"
        )
    result_kwargs["reduced_embeddings_file"] = embeddings_input_file

    file_manager = get_file_manager(**kwargs)

    return PROTOCOLS[protocol](file_manager, result_kwargs)
示例#5
0
def run(**kwargs):
    """
    Dispatch the project stage to the selected projection technique.

    Parameters
    ----------
    kwargs arguments (all required):
        protocol: Which projection technique to use, must be a key of PROTOCOLS
        prefix: Output prefix for all generated files
        stage_name: The stage name
        reduced_embeddings_file: Where per-protein embeddings live
        mapping_file: the mapping file generated by the pipeline when remapping indexes

    Returns
    -------
    Dictionary with results of stage, as returned by the selected protocol
    """
    required_options = [
        'protocol', 'prefix', 'stage_name', 'reduced_embeddings_file',
        'mapping_file'
    ]
    check_required(kwargs, required_options)

    selected = kwargs["protocol"]
    if selected in PROTOCOLS:
        # All options are handed to the protocol unchanged
        return PROTOCOLS[selected](**kwargs)

    valid_protocols = ", ".join(PROTOCOLS.keys())
    raise InvalidParameterError(
        "Invalid protocol selection: " +
        "{}. Valid protocols are: {}".format(selected, valid_protocols))
示例#6
0
def run(**kwargs):
    """
    Dispatch the visualize stage to the selected plotting protocol.

    Parameters
    ----------
    kwargs arguments (all required):
        protocol: Which plot to generate, must be a key of PROTOCOLS
        prefix: Output prefix for all generated files
        stage_name: The stage name
        projected_embeddings_file: A csv with columns: (index), original_id, x, y, z

    Returns
    -------
    Dictionary with results of stage, as returned by the selected protocol
    """
    required_options = ['protocol', 'prefix', 'stage_name', 'projected_embeddings_file']
    check_required(kwargs, required_options)

    selected = kwargs["protocol"]
    if selected in PROTOCOLS:
        # All options are handed to the protocol unchanged
        return PROTOCOLS[selected](**kwargs)

    valid_protocols = ", ".join(PROTOCOLS.keys())
    raise InvalidParameterError(
        "Invalid protocol selection: " +
        "{}. Valid protocols are: {}".format(selected, valid_protocols)
    )
示例#7
0
def _get_embeddings_file_context(
    file_manager: FileManagerInterface, result_kwargs: Dict[str, Any]
):
    """
    Provide the context manager for the per-amino-acid embeddings output.

    `discard_per_amino_acid_embeddings` is set to False in `result_kwargs` if it is missing.
    Unless it is exactly True, an h5 file is created and its path is stored under
    `embeddings_file`. If it is True no file is created, which is only allowed when `reduce`
    is enabled or an `embeddings_transformer_function` is defined.

    :param file_manager: The FileManager derived class which will be used to create the file
    :param result_kwargs: A dictionary which will be updated in-place to include the path to the newly created file

    :return: an h5py.File opened for writing, or a nullcontext when the embeddings are discarded
    """
    result_kwargs.setdefault("discard_per_amino_acid_embeddings", False)

    if result_kwargs["discard_per_amino_acid_embeddings"] is not True:
        # Regular case: the embeddings are written to a fresh h5 file
        output_path = file_manager.create_file(
            result_kwargs.get("prefix"),
            result_kwargs.get("stage_name"),
            "embeddings_file",
            extension=".h5",
        )
        result_kwargs["embeddings_file"] = output_path
        return h5py.File(output_path, "w")

    # Discarding is pointless if nothing else would be computed from the embeddings
    reduce_is_off = result_kwargs.get("reduce", False) is False
    has_no_transformer = result_kwargs.get("embeddings_transformer_function") is None
    if reduce_is_off and has_no_transformer:
        raise InvalidParameterError(
            "Cannot only have discard_per_amino_acid_embeddings: True. "
            "Either also set `reduce: True` or define an `embeddings_transformer_function`, or both."
        )
    return nullcontext()
示例#8
0
def _validate_file(file_path: str) -> None:
    """
    Verify that a file exists and is not empty.

    Nothing is returned; a file that passes the check simply does not raise.

    Parameters
    ----------
    file_path : str
        Path to file to check

    Raises
    ------
    InvalidParameterError
        If the file cannot be stat'ed (e.g. it does not exist, or the path is
        not a valid path object such as None), or if its size is zero.
    """
    try:
        file_size = os.stat(file_path).st_size
    except (OSError, TypeError) as e:
        raise InvalidParameterError(
            f"The configuration file at '{file_path}' does not exist") from e

    # Checked outside the try block so that the "is empty" error can never be
    # caught and re-reported as "does not exist" by the handler above
    if file_size == 0:
        raise InvalidParameterError(f"The file at '{file_path}' is empty")
示例#9
0
def run(**kwargs):
    """
    Run visualize protocol

    Parameters
    ----------
    kwargs arguments (* denotes optional):
        projected_reduced_embeddings_file: A csv with columns: (index), original_id, x, y, z.
            The legacy name projected_embeddings_file is accepted as a fallback
        prefix: Output prefix for all generated files
        stage_name: The stage name
        protocol: Which plot to generate, must be a key of PROTOCOLS

    Returns
    -------
    Dictionary with results of stage

    Raises
    ------
    InvalidParameterError
        If the protocol is unknown or no input file is provided
    """
    check_required(kwargs, ['protocol', 'prefix', 'stage_name'])

    if kwargs["protocol"] not in PROTOCOLS:
        raise InvalidParameterError(
            "Invalid protocol selection: " +
            "{}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(PROTOCOLS.keys())
            )
        )

    # Work on a copy so the caller's options stay untouched
    result_kwargs = deepcopy(kwargs)

    # Support legacy projected_embeddings_file
    projected_reduced_embeddings_file = (
        kwargs.get("projected_reduced_embeddings_file")
        or kwargs.get("projected_embeddings_file")
    )
    if not projected_reduced_embeddings_file:
        # Only the two option names read above are accepted here
        raise InvalidParameterError(
            f"You need to provide either projected_reduced_embeddings_file or projected_embeddings_file "
            f"for {kwargs['protocol']}"
        )
    result_kwargs["projected_reduced_embeddings_file"] = projected_reduced_embeddings_file

    return PROTOCOLS[kwargs["protocol"]](result_kwargs)
示例#10
0
def prepare_kwargs(**kwargs):
    """
    Validate the options of the embed stage and select the embedder class.

    Checks that all required options are present, that the protocol is known and installed,
    and that the half precision options are consistent. Options that are not known for the
    selected embedder are not rejected, they are only logged as warnings.

    Parameters
    ----------
    kwargs arguments:
        protocol, prefix, stage_name, remapped_sequences_file and mapping_file are required.
        All other options are optional.

    Returns
    -------
    A tuple of the embedder class for the protocol and a deep copy of the options in which
    `max_amino_acids` is filled in with the protocol's default if it was not set

    Raises
    ------
    InvalidParameterError
        If the protocol is unknown or its extra is not installed, if `use_cpu` is configured
        for unirep, or if the half precision options are unsupported or contradictory
    """
    required_kwargs = [
        "protocol",
        "prefix",
        "stage_name",
        "remapped_sequences_file",
        "mapping_file",
    ]
    check_required(kwargs, required_kwargs)

    if kwargs["protocol"] not in name_to_embedder:
        # A protocol that exists but is not loadable means the optional dependency is missing
        if kwargs["protocol"] in ALL_PROTOCOLS:
            raise InvalidParameterError(
                f"The extra for the protocol {kwargs['protocol']} is missing. "
                "See https://docs.bioembeddings.com/#installation on how to install all extras"
            )
        raise InvalidParameterError(
            "Invalid protocol selection: {}. Valid protocols are: {}".format(
                kwargs["protocol"], ", ".join(name_to_embedder.keys())))

    embedder_class = name_to_embedder[kwargs["protocol"]]

    if kwargs["protocol"] == "unirep" and kwargs.get("use_cpu") is not None:
        raise InvalidParameterError(
            "UniRep does not support configuring `use_cpu`")
    # See parameter_blueprints.yml
    global_options = {"sequences_file", "simple_remapping", "start_time"}
    embed_options = {
        "decoder",
        "device",
        "discard_per_amino_acid_embeddings",
        "half_precision_model",
        "half_precision",
        "max_amino_acids",
        "reduce",
        "type",
    }
    known_parameters = (set(required_kwargs)
                        | global_options
                        | embed_options
                        | set(embedder_class.necessary_files)
                        | set(embedder_class.necessary_directories))
    # Compare the protocol name: comparing the class itself with a string is never true
    if kwargs["protocol"] == "seqvec":
        # We support two ways of configuration for seqvec
        known_parameters.add("model_directory")
    # Unknown options are only warned about; sorted to keep the log order stable
    for option in sorted(set(kwargs) - known_parameters):
        logger.warning(
            f"You set an unknown option for {embedder_class.name}: {option} (value: {kwargs[option]})"
        )

    if kwargs.get("half_precision_model"):
        if kwargs["protocol"] not in [
                "prottrans_t5_bfd", "prottrans_t5_uniref50"
        ]:
            raise InvalidParameterError(
                "`half_precision_model` is only supported with prottrans_t5_bfd and prottrans_t5_uniref50"
            )

        if kwargs.get("half_precision") is False:  # None remains allowed
            raise InvalidParameterError(
                "You can't have `half_precision_model` be true and `half_precision` be false. "
                "We suggest also setting `half_precision` to true, "
                "which will compute and save embeddings as half-precision floats"
            )

    # Work on a copy so the caller's options stay untouched
    result_kwargs = deepcopy(kwargs)
    result_kwargs.setdefault("max_amino_acids",
                             DEFAULT_MAX_AMINO_ACIDS[kwargs["protocol"]])

    return embedder_class, result_kwargs
示例#11
0
def plotly(result_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Render a plotly scatter plot of projected embeddings and save it as html.

    The projected embeddings are read from `projected_reduced_embeddings_file` (.csv or .h5),
    optionally joined with the annotations from `annotation_file`, written out as a merged csv
    and plotted in 2D (`n_components` == 2) or 3D (any other allowed value).

    :param result_kwargs: The stage options. Updated in-place with `n_components`,
        `merged_annotation_file` and `plot_file`, and with `merge_via_index` and
        `display_unknown` if an annotation file is given

    :return: the updated `result_kwargs`

    :raises TooFewComponentsException: if `n_components` is smaller than 2
    :raises RuntimeError: if an embedding in the h5 file is not of shape (3,)
    :raises InvalidParameterError: if the input file is neither .csv nor .h5
    """
    file_manager = get_file_manager(**result_kwargs)

    # 2 or 3D plot? Usually, this is directly fetched from "project" stage via "depends_on"
    result_kwargs['n_components'] = result_kwargs.get('n_components', 3)
    if result_kwargs['n_components'] < 2:
        raise TooFewComponentsException(f"n_components is set to {result_kwargs['n_components']}. It should be >1.\n"
                                        f"If set to 2, will render 2D scatter plot.\n"
                                        f"If set to >=3, will render 3D scatter plot.")
    # Get projected_embeddings_file containing x,y,z coordinates and identifiers
    suffix = Path(result_kwargs["projected_reduced_embeddings_file"]).suffix
    if suffix == ".csv":
        # Support the legacy csv format
        merged_annotation_file = read_csv(
            result_kwargs["projected_reduced_embeddings_file"], index_col=0
        )
    elif suffix == ".h5":
        # convert h5 to dataframe with ids and one column per dimension
        rows = []
        with h5py.File(result_kwargs["projected_reduced_embeddings_file"], "r") as file:
            for sequence_id, embedding in file.items():
                if embedding.shape != (3,):
                    raise RuntimeError(
                        f"Expected embeddings in projected_reduced_embeddings_file "
                        f"to be of shape (3,), not {embedding.shape}"
                    )
                row = (
                    sequence_id,
                    embedding.attrs["original_id"],
                    embedding[0],
                    embedding[1],
                    embedding[2],
                )
                rows.append(row)
        columns = [
            "sequence_id",
            "original_id",
            "component_0",
            "component_1",
            "component_2",
        ]
        merged_annotation_file = DataFrame.from_records(
            rows, index="sequence_id", columns=columns
        )
    else:
        raise InvalidParameterError(
            f"Expected .csv or .h5 as suffix for projected_reduced_embeddings_file, got {suffix}"
        )

    if result_kwargs.get('annotation_file'):

        annotation_file = read_csv(result_kwargs['annotation_file']).set_index('identifier')

        # Save a copy of the annotation file with index set to identifier
        input_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                              result_kwargs.get('stage_name'),
                                                              'input_annotation_file',
                                                              extension='.csv')

        annotation_file.to_csv(input_annotation_file_path)

        # Merge annotation file and projected embedding file based on index or original id?
        result_kwargs['merge_via_index'] = result_kwargs.get('merge_via_index', False)

        # Display proteins with unknown annotation?
        result_kwargs['display_unknown'] = result_kwargs.get('display_unknown', True)

        # Pick the key of the projected embeddings that matches the annotation identifiers
        if result_kwargs['merge_via_index']:
            projected_embeddings = merged_annotation_file
        else:
            projected_embeddings = merged_annotation_file.set_index('original_id')

        if result_kwargs['display_unknown']:
            # The outer join keeps proteins without annotation, which then get a placeholder label
            merged_annotation_file = annotation_file.join(projected_embeddings, how="outer")
            # Assign the result back instead of a chained `fillna(..., inplace=True)`,
            # which does not update the dataframe under pandas Copy-on-Write
            merged_annotation_file['label'] = merged_annotation_file['label'].fillna('UNKNOWN')
        else:
            # The default (left) join keeps only the identifiers of the annotation file
            merged_annotation_file = annotation_file.join(projected_embeddings)
    else:
        merged_annotation_file['label'] = 'UNKNOWN'

    merged_annotation_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                                           result_kwargs.get('stage_name'),
                                                           'merged_annotation_file',
                                                           extension='.csv')

    merged_annotation_file.to_csv(merged_annotation_file_path)
    result_kwargs['merged_annotation_file'] = merged_annotation_file_path

    if result_kwargs['n_components'] == 2:
        figure = render_scatter_plotly(merged_annotation_file)
    else:
        figure = render_3D_scatter_plotly(merged_annotation_file)

    plot_file_path = file_manager.create_file(result_kwargs.get('prefix'),
                                              result_kwargs.get('stage_name'),
                                              'plot_file',
                                              extension='.html')

    save_plotly_figure_to_html(figure, plot_file_path)
    result_kwargs['plot_file'] = plot_file_path

    return result_kwargs