Пример #1
0
def test_mle_hard(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE for a graph with a node missing an absolute value
    can get it right based on relative results.

    Parameters
    ----------
    input_absolutes : list of float
        Reference absolute value for each node; node ``i`` is compared
        against ``input_absolutes[i]``. The list is only read here.
    """

    # Build a small directed test graph
    graph = nx.DiGraph()
    # Don't assign the first absolute value, check that MLE can get close to it
    for i, val in enumerate(input_absolutes):
        if i == 0:
            graph.add_node(i)
        else:
            graph.add_node(i, f_i=val, f_di=0.5)

    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # Perturb the true difference with uniform noise and widen the
        # reported edge uncertainty by the magnitude of that perturbation
        noise = np.random.uniform(low=-1.0, high=1.0)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, covar = stats.mle(graph,
                                        factor="f_ij",
                                        node_factor="f_i")

    # NOTE(review): the deviation is compared against the diagonal entry
    # covar[i, i] itself; if that entry is a variance, confirm whether its
    # square root was the intended tolerance.
    for i, _ in enumerate(graph.nodes(data=True)):
        abs_err = np.abs(output_absolutes[i] - input_absolutes[i])
        assert (abs_err < covar[i, i]), (
            f"MLE error. Output absolute estimate, {output_absolutes[i]}, "
            f"is too far from true value: {input_absolutes[i]}.")
Пример #2
0
def test_mle_relative(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE can get the relative differences correct
    when no absolute values are provided.

    Parameters
    ----------
    input_absolutes : list of float
        Reference absolute value for each node; only the pairwise
        differences between entries are checked. The list is only read here.
    """

    graph = nx.DiGraph()
    # Don't assign any absolute values: nodes are created implicitly by
    # add_edge and therefore carry no "f_i" attribute
    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # Perturb the true difference with uniform noise and widen the
        # reported edge uncertainty by the magnitude of that perturbation
        noise = np.random.uniform(low=-0.5, high=0.5)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, _ = stats.mle(graph, factor="f_ij", node_factor="f_i")

    # Every unordered pair of node indices
    pairs = itertools.combinations(range(len(input_absolutes)), 2)

    for i, j in pairs:
        mle_diff = output_absolutes[i] - output_absolutes[j]
        true_diff = input_absolutes[i] - input_absolutes[j]

        assert (np.abs(true_diff - mle_diff) < 1.0), (
            f"Relative difference between nodes {i} and {j} from MLE, "
            f"{mle_diff}, is too far from true difference: {true_diff}.")
Пример #3
0
def test_mle_easy(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE for a graph with an absolute
    estimate on all nodes will recapitulate it.

    Parameters
    ----------
    input_absolutes : list of float
        Reference absolute value for each node; node ``i`` is compared
        against ``input_absolutes[i]``. The list is only read here.
    """

    graph = nx.DiGraph()
    # Every node gets its absolute value and a fixed uncertainty
    for i, val in enumerate(input_absolutes):
        graph.add_node(i, f_i=val, f_di=0.5)

    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # Perturb the true difference with uniform noise and widen the
        # reported edge uncertainty by the magnitude of that perturbation
        noise = np.random.uniform(low=-1.0, high=1.0)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, covar = stats.mle(graph,
                                        factor="f_ij",
                                        node_factor="f_i")

    # NOTE(review): the deviation is compared against the diagonal entry
    # covar[i, i] itself; if that entry is a variance, confirm whether its
    # square root was the intended tolerance.
    for i, _ in enumerate(graph.nodes(data=True)):
        abs_err = np.abs(output_absolutes[i] - input_absolutes[i])
        assert (abs_err < covar[i, i]), (
            f"MLE error. Output absolute estimate, {output_absolutes[i]}, "
            f"is too far from true value: {input_absolutes[i]}.")
Пример #4
0
def combine_free_energies(
    compounds: List[Compound],
    transformations: List[TransformationAnalysis],
) -> List[CompoundAnalysis]:
    """
    Perform DiffNet MLE analysis to compute free energies for all
    microstates given experimental free energies for a subset, and
    relative free energies of transformations.

    The analysis runs in two MLE passes over each weakly-connected
    subgraph that contains at least one compound with a "pIC50" entry:

    1. A first pass using only the relative edge data ("g_ij"), stored on
       each node as "g1" / "g1_err".
    2. Compound-level experimental values are apportioned over the
       compound's microstates using the first-pass results, stored as
       "g_exp" / "g_dexp", and a second pass using these as node
       references stores "g" / "g_err".

    Parameters
    ----------
    compounds : list of Compound
    transformations : list of TransformationAnalysis

    Returns
    -------
    List of CompoundAnalysis
        Result of DiffNet MLE analysis, one entry per input compound in
        input order. Microstates without results get ``None`` for the
        corresponding free energy; a compound whose free energy cannot
        be estimated (``AnalysisError``) gets ``free_energy=None``.
    """

    from openff.arsenic import stats

    # Type assertions (useful for type checking with mypy)
    node: CompoundMicrostate
    microstate: Microstate

    supergraph = build_transformation_graph(compounds, transformations)

    # Split supergraph into weakly-connected subgraphs
    # NOTE: the subgraphs are "views" into the supergraph, meaning
    # updates made to the subgraphs are reflected in the supergraph
    # (we exploit this below)
    connected_subgraphs = [
        supergraph.subgraph(nodes)
        for nodes in nx.weakly_connected_components(supergraph)
    ]

    # Filter to connected subgraphs containing at least one
    # experimental measurement
    valid_subgraphs = [
        graph for graph in connected_subgraphs if any(
            "pIC50" in graph.nodes[node]["compound"].metadata.experimental_data
            for node in graph)
    ]

    if len(valid_subgraphs) < len(connected_subgraphs):
        logging.warning(
            "Found %d out of %d connected subgraphs without experimental data",
            len(connected_subgraphs) - len(valid_subgraphs),
            len(connected_subgraphs),
        )

    # Initial MLE pass: compute microstate free energies without using
    # experimental reference values
    for idx, graph in enumerate(valid_subgraphs):
        # NOTE: no node_factor argument in the following
        # (because we do not use experimental data for the first pass)
        g1s, C1 = stats.mle(graph, factor="g_ij")
        errs = np.sqrt(np.diag(C1))
        for node, g1, g1_err in zip(graph.nodes, g1s, errs):
            graph.nodes[node]["g1"] = g1
            graph.nodes[node]["g1_err"] = g1_err
            # Remember which valid subgraph the node belongs to; nodes in
            # subgraphs without experimental data never get this key
            graph.nodes[node]["subgraph_index"] = idx

    # Use first-pass microstate free energies g_1[c,i] to distribute
    # compound-level experimental data g_exp_compound[c] over
    # microstates, using the formula:
    #
    #    g_exp[c,i] = g_exp_compound[c]
    #               - ln( exp(-(s[c,i] + g_1[c,i]))
    #                   / sum(exp(-(s[c,:] + g_1[c,:])))
    #                   )
    #
    for compound in compounds:
        pIC50 = compound.metadata.experimental_data.get("pIC50")

        # Skip compounds with no experimental data
        if pIC50 is None:
            continue

        g_exp_compound = pIC50_to_DG(pIC50)

        nodes = [
            CompoundMicrostate(
                compound_id=compound.metadata.compound_id,
                microstate_id=microstate.microstate_id,
            ) for microstate in compound.microstates
        ]

        # Filter to nodes that are part of a connected subgraph with
        # at least one experimental measurement, grouped by subgraph

        subgraph_valid_nodes: Dict[int, List[Tuple[CompoundMicrostate,
                                                   Microstate]]]
        subgraph_valid_nodes = defaultdict(list)
        for node, microstate in zip(nodes, compound.microstates):
            if node not in supergraph:
                # This is the only point at which a microstate missing
                # from the graph is observable: everything downstream
                # only sees nodes that passed this check
                logging.warning(
                    "Compound microstate '%s' has experimental data, "
                    "but does not appear in any transformation",
                    node.microstate_id,
                )
                continue
            if "subgraph_index" in supergraph.nodes[node]:
                idx = supergraph.nodes[node]["subgraph_index"]
                subgraph_valid_nodes[idx].append((node, microstate))

        # Skip compound if none of its microstates are in a subgraph
        # with experimental data
        if not subgraph_valid_nodes:
            continue

        # Pick the subgraph containing the largest number of microstates
        # (on a tie, max() keeps the first group encountered)
        valid_nodes = max(subgraph_valid_nodes.values(), key=len)

        g_is = np.array([
            microstate.free_energy_penalty.point + supergraph.nodes[node]["g1"]
            for node, microstate in valid_nodes
        ])

        # Compute normalized microstate probabilities
        p_is = np.exp(-g_is - logsumexp(-g_is))

        # Apportion compound K_a according to microstate probability
        Ka_is = p_is * np.exp(-g_exp_compound)

        for (node, _), Ka in zip(valid_nodes, Ka_is):
            supergraph.nodes[node]["g_exp"] = -np.log(Ka)
            # NOTE: naming of uncertainty fixed by Arsenic convention
            # TODO: remove hard-coded value
            supergraph.nodes[node]["g_dexp"] = 0.1 * KCALMOL_KT

    # Second pass: use first-pass microstate free energies and
    # compound experimental data to compute microstate absolute free
    # energies.

    for graph in valid_subgraphs:
        gs, C = stats.mle(graph, factor="g_ij", node_factor="g_exp")
        errs = np.sqrt(np.diag(C))
        for node, g, g_err in zip(graph.nodes, gs, errs):
            graph.nodes[node]["g"] = g
            graph.nodes[node]["g_err"] = g_err

    def get_compound_analysis(compound: Compound) -> CompoundAnalysis:
        """Collect the per-microstate and compound-level results."""

        def get_microstate_analysis(
                microstate: Microstate) -> MicrostateAnalysis:
            """Read the stored MLE results for one microstate, if any."""

            node = CompoundMicrostate(
                compound_id=compound.metadata.compound_id,
                microstate_id=microstate.microstate_id,
            )

            # None when the microstate is not a node of the supergraph
            data = supergraph.nodes.get(node)

            return MicrostateAnalysis(
                microstate=microstate,
                free_energy=PointEstimate(point=data["g"],
                                          stderr=data["g_err"])
                if data and "g" in data and "g_err" in data else None,
                first_pass_free_energy=PointEstimate(point=data["g1"],
                                                     stderr=data["g1_err"])
                if data and "g1" in data and "g1_err" in data else None,
            )

        microstates = [
            get_microstate_analysis(microstate)
            for microstate in compound.microstates
        ]

        free_energy: Optional[PointEstimate]
        try:
            free_energy = get_compound_free_energy(microstates)
        except AnalysisError as exc:
            logging.info(
                "Failed to estimate free energy for compound '%s': %s",
                compound.metadata.compound_id,
                exc,
            )
            free_energy = None

        return CompoundAnalysis(metadata=compound.metadata,
                                microstates=microstates,
                                free_energy=free_energy)

    return [get_compound_analysis(compound) for compound in compounds]