def test_mle_hard(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE for a graph with a node missing an absolute value
    can get it right based on relative results

    Node 0 is added without an ``f_i`` attribute; every other node carries
    its input absolute value with an uncertainty of 0.5. Edges carry the
    true difference plus uniform noise in [-1, 1].
    """
    graph = nx.DiGraph()

    # Don't assign the first absolute value, check that MLE can get close
    # to it
    for i, val in enumerate(input_absolutes):
        if i == 0:
            graph.add_node(i)
        else:
            graph.add_node(i, f_i=val, f_di=0.5)

    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # NOTE(review): noise is drawn from the unseeded global NumPy RNG,
        # so this test is not deterministic.
        noise = np.random.uniform(low=-1.0, high=1.0)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, covar = stats.mle(graph, factor="f_ij", node_factor="f_i")

    # Nodes were inserted as 0..len(input_absolutes)-1, so index i of the
    # MLE output is compared against input_absolutes[i].
    # NOTE(review): the bound is the diagonal covariance entry itself, not
    # its square root -- confirm this is the intended tolerance.
    for i, expected in enumerate(input_absolutes):
        diff = np.abs(output_absolutes[i] - expected)
        assert diff < covar[i, i], (
            f"MLE error. Output absolute estimate, {output_absolutes[i]}, "
            f"is too far from true value: {expected}."
        )
def test_mle_relative(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE can get the relative differences correct
    when no absolute values are provided

    Only edges are added to the graph (no node attributes). Each edge
    carries the true difference plus uniform noise in [-0.5, 0.5]; every
    pairwise difference of the MLE output must be within 1.0 of the true
    pairwise difference.
    """
    graph = nx.DiGraph()

    # Don't assign any absolute values
    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # NOTE(review): noise is drawn from the unseeded global NumPy RNG,
        # so this test is not deterministic.
        noise = np.random.uniform(low=-0.5, high=0.5)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, _ = stats.mle(graph, factor="f_ij", node_factor="f_i")

    pairs = itertools.combinations(range(len(input_absolutes)), 2)
    for i, j in pairs:
        mle_diff = output_absolutes[i] - output_absolutes[j]
        true_diff = input_absolutes[i] - input_absolutes[j]
        assert np.abs(true_diff - mle_diff) < 1.0, (
            f"Relative difference from MLE, {mle_diff}, "
            f"is too far from true difference: {true_diff}."
        )
def test_mle_easy(input_absolutes: list = [-14.0, -13.0, -9.0]):
    """
    Test that the MLE for a graph with an absolute
    estimate on all nodes will recapitulate it

    Every node carries its input absolute value with an uncertainty of
    0.5. Edges carry the true difference plus uniform noise in [-1, 1].
    """
    graph = nx.DiGraph()
    for i, val in enumerate(input_absolutes):
        graph.add_node(i, f_i=val, f_di=0.5)

    edges = [(0, 1), (0, 2), (2, 1)]
    for node1, node2 in edges:
        # NOTE(review): noise is drawn from the unseeded global NumPy RNG,
        # so this test is not deterministic.
        noise = np.random.uniform(low=-1.0, high=1.0)
        diff = input_absolutes[node2] - input_absolutes[node1] + noise
        graph.add_edge(node1, node2, f_ij=diff, f_dij=0.5 + np.abs(noise))

    output_absolutes, covar = stats.mle(graph, factor="f_ij", node_factor="f_i")

    # Nodes were inserted as 0..len(input_absolutes)-1, so index i of the
    # MLE output is compared against input_absolutes[i].
    # NOTE(review): the bound is the diagonal covariance entry itself, not
    # its square root -- confirm this is the intended tolerance.
    for i, expected in enumerate(input_absolutes):
        diff = np.abs(output_absolutes[i] - expected)
        assert diff < covar[i, i], (
            f"MLE error. Output absolute estimate, {output_absolutes[i]}, "
            f"is too far from true value: {expected}."
        )
def combine_free_energies(
    compounds: List[Compound],
    transformations: List[TransformationAnalysis],
) -> List[CompoundAnalysis]:
    """
    Perform DiffNet MLE analysis to compute free energies for all
    microstates given experimental free energies for a subset, and
    relative free energies of transformations.

    The analysis runs over every weakly-connected subgraph of the
    transformation graph that contains at least one compound with a
    "pIC50" entry in its experimental data:

    1. A first MLE pass without experimental reference values stores
       ``g1`` / ``g1_err`` (and ``subgraph_index``) on each node.
    2. Compound-level experimental free energies are apportioned over
       microstates using the first-pass values and stored on the nodes
       as ``g_exp`` / ``g_dexp``.
    3. A second MLE pass using ``g_exp`` as node factor stores
       ``g`` / ``g_err`` on each node.

    Parameters
    ----------
    compounds : list of Compound
    transformations : list of TransformationAnalysis

    Returns
    -------
    List of CompoundAnalysis
        Result of DiffNet MLE analysis
    """
    from openff.arsenic import stats

    # Type assertions (useful for type checking with mypy)
    node: CompoundMicrostate
    microstate: Microstate

    supergraph = build_transformation_graph(compounds, transformations)

    # Split supergraph into weakly-connected subgraphs
    # NOTE: the subgraphs are "views" into the supergraph, meaning
    # updates made to the subgraphs are reflected in the supergraph
    # (we exploit this below)
    connected_subgraphs = [
        supergraph.subgraph(nodes)
        for nodes in nx.weakly_connected_components(supergraph)
    ]

    # Filter to connected subgraphs containing at least one
    # experimental measurement
    valid_subgraphs = [
        graph
        for graph in connected_subgraphs
        if any(
            "pIC50" in graph.nodes[node]["compound"].metadata.experimental_data
            for node in graph
        )
    ]

    if len(valid_subgraphs) < len(connected_subgraphs):
        logging.warning(
            "Found %d out of %d connected subgraphs without experimental data",
            len(connected_subgraphs) - len(valid_subgraphs),
            len(connected_subgraphs),
        )

    # Initial MLE pass: compute microstate free energies without using
    # experimental reference values
    for idx, graph in enumerate(valid_subgraphs):
        # NOTE: no node_factor argument in the following
        # (because we do not use experimental data for the first pass)
        g1s, C1 = stats.mle(graph, factor="g_ij")
        errs = np.sqrt(np.diag(C1))
        for node, g1, g1_err in zip(graph.nodes, g1s, errs):
            graph.nodes[node]["g1"] = g1
            graph.nodes[node]["g1_err"] = g1_err
            graph.nodes[node]["subgraph_index"] = idx

    # Use first-pass microstate free energies g_1[c,i] to distribute
    # compound-level experimental data g_exp_compound[c] over
    # microstates, using the formula:
    #
    #   g_exp[c,i] = g_exp_compound[c]
    #                - ln( exp(-(s[c,i] + g_1[c,i]))
    #                      / sum(exp(-(s[c,:] + g_1[c,:])))
    #                )
    #
    for compound in compounds:
        pIC50 = compound.metadata.experimental_data.get("pIC50")

        # Skip compounds with no experimental data
        if pIC50 is None:
            continue

        g_exp_compound = pIC50_to_DG(pIC50)

        nodes = [
            CompoundMicrostate(
                compound_id=compound.metadata.compound_id,
                microstate_id=microstate.microstate_id,
            )
            for microstate in compound.microstates
        ]

        # Filter to nodes that are part of a connected subgraph with
        # at least one experimental measurement
        subgraph_valid_nodes: Dict[int, List[Tuple[CompoundMicrostate, Microstate]]]
        subgraph_valid_nodes = defaultdict(list)
        for node, microstate in zip(nodes, compound.microstates):
            if node not in supergraph:
                # This is the only place a microstate missing from the
                # transformation graph can be observed; everything kept
                # below is guaranteed to be a graph node.
                logging.warning(
                    "Compound microstate '%s' has experimental data, "
                    "but does not appear in any transformation",
                    node.microstate_id,
                )
                continue
            if "subgraph_index" in supergraph.nodes[node]:
                subgraph_idx = supergraph.nodes[node]["subgraph_index"]
                subgraph_valid_nodes[subgraph_idx].append((node, microstate))

        # Skip compound if none of its microstates are in a subgraph
        # with experimental data
        if not subgraph_valid_nodes:
            continue

        # Pick the subgraph containing the largest number of microstates
        valid_nodes = max(subgraph_valid_nodes.values(), key=len)

        g_is = np.array(
            [
                microstate.free_energy_penalty.point + supergraph.nodes[node]["g1"]
                for node, microstate in valid_nodes
            ]
        )

        # Compute normalized microstate probabilities
        p_is = np.exp(-g_is - logsumexp(-g_is))

        # Apportion compound K_a according to microstate probability
        Ka_is = p_is * np.exp(-g_exp_compound)

        for (node, _), Ka in zip(valid_nodes, Ka_is):
            supergraph.nodes[node]["g_exp"] = -np.log(Ka)
            # NOTE: naming of uncertainty fixed by Arsenic convention
            # TODO: remove hard-coded value
            supergraph.nodes[node]["g_dexp"] = 0.1 * KCALMOL_KT

    # Second pass: use first-pass microstate free energies and
    # compound experimental data to compute microstate absolute free
    # energies.
    for graph in valid_subgraphs:
        gs, C = stats.mle(graph, factor="g_ij", node_factor="g_exp")
        errs = np.sqrt(np.diag(C))
        for node, g, g_err in zip(graph.nodes, gs, errs):
            graph.nodes[node]["g"] = g
            graph.nodes[node]["g_err"] = g_err

    def get_point_estimate(data, point_key, stderr_key):
        # Build a PointEstimate from node attributes, or None when the
        # node is absent from the graph or lacks either attribute.
        if data and point_key in data and stderr_key in data:
            return PointEstimate(point=data[point_key], stderr=data[stderr_key])
        return None

    def get_compound_analysis(compound: Compound) -> CompoundAnalysis:
        def get_microstate_analysis(microstate: Microstate) -> MicrostateAnalysis:
            node = CompoundMicrostate(
                compound_id=compound.metadata.compound_id,
                microstate_id=microstate.microstate_id,
            )
            # None when the microstate is not a node of the supergraph
            data = supergraph.nodes.get(node)
            return MicrostateAnalysis(
                microstate=microstate,
                free_energy=get_point_estimate(data, "g", "g_err"),
                first_pass_free_energy=get_point_estimate(data, "g1", "g1_err"),
            )

        microstates = [
            get_microstate_analysis(microstate)
            for microstate in compound.microstates
        ]

        free_energy: Optional[PointEstimate]
        try:
            free_energy = get_compound_free_energy(microstates)
        except AnalysisError as exc:
            # Best effort: a compound without an estimate is still reported
            logging.info(
                "Failed to estimate free energy for compound '%s': %s",
                compound.metadata.compound_id,
                exc,
            )
            free_energy = None

        return CompoundAnalysis(
            metadata=compound.metadata,
            microstates=microstates,
            free_energy=free_energy,
        )

    return [get_compound_analysis(compound) for compound in compounds]