def test_set_reactants_single():
    """Assign a single molecule to ``reactants`` and check the value read back."""
    product = TreeMolecule(parent=None, smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")
    only_reactant = TreeMolecule(parent=product, smiles="N#Cc1cccc(N)c1F")
    reaction = FixedRetroReaction(product)

    reaction.reactants = only_reactant

    # A 1-tuple holding the molecule itself (parentheses around a lone
    # name do not create a nested tuple).
    expected = (only_reactant,)
    assert reaction.reactants == expected
def test_set_reactants_list_of_list():
    """Assign a nested tuple of molecules to ``reactants`` and read it back unchanged."""
    product = TreeMolecule(parent=None, smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")
    first_reactant = TreeMolecule(parent=product, smiles="N#Cc1cccc(N)c1F")
    second_reactant = TreeMolecule(parent=product, smiles="O=C(Cl)c1ccc(F)cc1")
    reaction = FixedRetroReaction(product)

    reaction.reactants = ((first_reactant, second_reactant),)

    expected = ((first_reactant, second_reactant),)
    assert reaction.reactants == expected
def test_chaining():
    """Serialize a child molecule, then deserialize both it and its parent from the store."""
    serializer = MoleculeSerializer()
    parent_mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child_mol = TreeMolecule(smiles="CCO", parent=parent_mol)

    # Only the child is serialized explicitly; the parent is looked up
    # below by its own ``id``.
    child_id = serializer[child_mol]
    parent_id = id(parent_mol)

    deserializer = MoleculeDeserializer(serializer.store)
    restored_child = deserializer[child_id]
    restored_parent = deserializer[parent_id]

    assert restored_child.smiles == child_mol.smiles
    assert restored_parent.smiles == parent_mol.smiles
    # The deserialized objects must not be the original instances.
    assert id(restored_child) != child_id
    assert id(restored_parent) != parent_id
def test_get_actions(default_config, setup_template_expansion_policy):
    """Check ``get_actions`` priors before/after selecting a policy and changing cutoffs."""
    strategy, _ = setup_template_expansion_policy()
    policy = default_config.expansion_policy
    policy.load(strategy)
    targets = [TreeMolecule(smiles="CCO", parent=None)]

    # Nothing has been selected yet, so the call must fail.
    with pytest.raises(PolicyException, match="selected"):
        policy.get_actions(targets)

    policy.select("policy1")
    actions, priors = policy.get_actions(targets)
    assert priors == [0.7, 0.2]
    names = [action.metadata["policy_name"] for action in actions]
    assert names == ["policy1", "policy1"]

    policy._config.cutoff_cumulative = 1.0
    _, priors = policy.get_actions(targets)
    assert priors == [0.7, 0.2, 0.1]

    policy._config.cutoff_number = 1
    _, priors = policy.get_actions(targets)
    assert priors == [0.7]
def wrapper(root_smiles, config):
    """Set the mocked create-root return value to a node built from ``root_smiles``.

    :param root_smiles: the SMILES used for the root molecule
    :param config: the configuration passed to the state and the node
    :return: the root molecule that was created
    """
    root_mol = TreeMolecule(parent=None, transform=0, smiles=root_smiles)
    root_state = State(mols=[root_mol], config=config)
    root_node = Node(state=root_state, owner=None, config=config)
    mocked_create_root.return_value = root_node
    return root_mol
def test_retro_reaction_fingerprint(simple_actions):
    """Check the fingerprint of the first action, requested with arguments ``(2, 10)``."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)

    fingerprint = reactions[0].fingerprint(2, 10)

    expected = [0, -1, 0, -1, -1, 0, -1, -1, 0, 0]
    assert list(fingerprint) == expected
def test_save_molecule_images():
    """Count files in the image folder after saving molecule images twice."""
    baseline = len(os.listdir(image.IMAGE_FOLDER))
    mols = [
        TreeMolecule(smiles=smi, parent=None) for smi in ("CCCO", "CCCO", "CCCCO")
    ]

    # Three molecules with two distinct SMILES, all green: two new files expected.
    image.save_molecule_images(mols, ["green"] * 3)
    assert len(os.listdir(image.IMAGE_FOLDER)) == baseline + 2

    # Recolouring the second molecule adds exactly one more file.
    image.save_molecule_images(mols, ["green", "orange", "green"])
    assert len(os.listdir(image.IMAGE_FOLDER)) == baseline + 3
def _score_reaction_tree(self, tree: ReactionTree) -> float:
    """
    Score a reaction tree through the state formed by its leaf molecules.

    :param tree: the reaction tree to score
    :return: the ``score`` of the state built from the leaves
    """
    leaf_mols = []
    for leaf in tree.leafs():
        # The transform is the leaf depth halved with integer division.
        transform = tree.depth(leaf) // 2
        leaf_mols.append(
            TreeMolecule(parent=None, transform=transform, smiles=leaf.smiles)
        )
    return MctsState(leaf_mols, self._config).score
def test_smiles_based_retroreaction():
    """Build a reaction from a dot-separated reactant string and check its parts."""
    target = TreeMolecule(smiles="CNC(C)=O", parent=None)

    reaction = SmilesBasedRetroReaction(target, reactants_str="CC(=O)O.CN")

    # A single outcome holding the two reactants in the order given.
    assert len(reaction.reactants) == 1
    outcome = reaction.reactants[0]
    assert outcome[0].smiles == "CC(=O)O"
    assert outcome[1].smiles == "CN"
    assert reaction.smiles == "CNC(C)=O>>CC(=O)O.CN"
def test_reaction_failure_rdchiral(simple_actions, mocker):
    """``apply`` returns a falsy result when the patched ``rdchiralRun`` raises."""
    patched_rdchiral_run = mocker.patch("aizynthfinder.chem.rdc.rdchiralRun")
    patched_rdchiral_run.side_effect = RuntimeError("Oh no!")
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)

    outcome = reactions[0].apply()

    assert not outcome
def setup_graphviz_graph():
    """Return a ``GraphvizReactionGraph`` with one green molecule, one reaction and an edge between them."""
    molecule = TreeMolecule(smiles="CCCO", parent=None)
    reaction = RetroReaction(mol=molecule, smarts="")

    graph = image.GraphvizReactionGraph()
    graph.add_molecule(molecule, "green")
    graph.add_reaction(reaction)
    graph.add_edge(molecule, reaction)
    return graph
def test_retro_reaction(simple_actions):
    """Apply the first and third actions and check the resulting reactants."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)

    first_outcomes = reactions[0].apply()
    first_set = first_outcomes[0]
    assert first_set[0].smiles == "CCCCOc1ccc(CC(=O)Cl)cc1"
    assert first_set[1].smiles == "CNO"

    # The third action is expected to give an empty tuple.
    third_outcomes = reactions[2].apply()
    assert third_outcomes == ()
def test_expander_top1(mock_expansion_policy):
    """With ``return_n=1`` the expander returns a single group of reactions."""
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    expander = AiZynthExpander()
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))

    groups = expander.do_expansion(target_smiles, return_n=1)

    assert len(groups) == 1
    top_reaction = groups[0][0]
    reactant_smiles = [mol.smiles for mol in top_reaction.reactants[0]]
    assert reactant_smiles == ["CCCCOc1ccc(CC(=O)Cl)cc1", "CNO"]
def test_feasible(filter_policy, mock_policy_model, mocker, simple_actions):
    """``is_feasible`` flips from False to True when the filter cutoff is lowered."""
    filter_policy.load("dummy.hdf5", "policy1")
    filter_policy.select("policy1")
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)
    candidate = reactions[0]

    filter_policy._config.filter_cutoff = 0.9
    assert not filter_policy.is_feasible(candidate)

    filter_policy._config.filter_cutoff = 0.15
    assert filter_policy.is_feasible(candidate)
def test_expander_filter_policy(mock_expansion_policy, mock_policy_model):
    """With a filter policy selected, each returned reaction carries a feasibility of 0.2."""
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    expander = AiZynthExpander()
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))
    expander.filter_policy.load("dummy.hdf5", "policy1")
    expander.filter_policy.select("policy1")

    groups = expander.do_expansion(target_smiles)

    assert len(groups) == 2
    # Exactly two groups (asserted above); check the first reaction of each.
    for group in groups:
        assert group[0].metadata["feasibility"] == 0.2
def test_create_fixed_reaction():
    """A fixed reaction keeps its SMILES but rejects ``rd_reaction`` and ``apply``."""
    reaction_smiles = (
        "[C:1](=[O:2])([cH3:3])[N:4][cH3:5]>>Cl[C:1](=[O:2])[cH3:3].[N:4][cH3:5]"
    )
    product = TreeMolecule(parent=None, smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")

    reaction = FixedRetroReaction(product, smiles=reaction_smiles)

    assert reaction.smiles == reaction_smiles

    # Merely reading the attribute is expected to raise.
    with pytest.raises(ValueError):
        _ = reaction.rd_reaction

    with pytest.raises(NotImplementedError):
        reaction.apply()
def create_root(cls, smiles: str, tree: MctsSearchTree, config: Configuration) -> "MctsNode":
    """
    Build the root node of a search tree from a SMILES string.

    The root state holds a single parent-less molecule with transform 0.

    :param smiles: the SMILES representation of the root state
    :param tree: the search tree that will own the node
    :param config: settings of the tree search algorithm
    :return: the created node
    """
    root_mol = TreeMolecule(parent=None, transform=0, smiles=smiles)
    root_state = MctsState(mols=[root_mol], config=config)
    return MctsNode(state=root_state, owner=tree, config=config)
def test_expander_filter(mock_expansion_policy):
    """A custom ``filter_func`` removes the reaction that produces ``CNO``."""

    def keep_without_cno(reaction):
        # Keep the reaction only if none of its first-outcome reactants is "CNO".
        return all(mol.smiles != "CNO" for mol in reaction.reactants[0])

    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    expander = AiZynthExpander()
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))

    groups = expander.do_expansion(target_smiles, filter_func=keep_without_cno)

    assert len(groups) == 1
    remaining = groups[0][0]
    reactant_smiles = [mol.smiles for mol in remaining.reactants[0]]
    assert reactant_smiles == ["CCCCBr", "CN(O)C(=O)Cc1ccc(O)cc1"]
def do_expansion(
    self,
    smiles: str,
    return_n: int = 5,
    filter_func: Callable[[RetroReaction], bool] = None,
) -> List[Tuple[FixedRetroReaction, ...]]:
    """
    Expand the given molecule and return the resulting reactions in groups.

    Every tuple in the returned list holds reactions that produce the same
    set of reactants, so the nesting of the return value is the grouping.

    When a filter policy is selected, the first selected policy that has a
    ``feasibility`` method is used to attach a feasibility probability to
    the reaction metadata.

    An extra, user-supplied filter can be given: a callable that takes a
    `RetroReaction` as its only argument and returns True to keep the
    reaction or False to drop it.

    :param smiles: the SMILES string of the target molecule
    :param return_n: the maximum number of groups in the returned list
    :param filter_func: an additional filter function
    :return: the grouped reactions
    """
    target = TreeMolecule(parent=None, smiles=smiles)
    actions, _ = self.expansion_policy.get_actions([target])

    grouped: Dict[Tuple[str, ...], List[FixedRetroReaction]] = defaultdict(list)
    for action in actions:
        reactants = action.reactants
        # Skip actions without reactants and those the caller filters out.
        if not reactants or (filter_func and not filter_func(action)):
            continue

        # Only the first selected policy exposing ``feasibility`` is consulted.
        for name in self.filter_policy.selection or []:
            policy = self.filter_policy[name]
            if not hasattr(policy, "feasibility"):
                continue
            _, feasibility_prob = policy.feasibility(action)
            action.metadata["feasibility"] = float(feasibility_prob)
            break

        # The rank is based on the number of groups collected so far.
        action.metadata["expansion_rank"] = len(grouped) + 1

        # Reactions are grouped by the sorted InChI keys of their reactants.
        group_key = tuple(sorted(mol.inchi_key for mol in reactants[0]))
        opens_new_group = group_key not in grouped
        if opens_new_group and len(grouped) >= return_n:
            # Enough groups already; only existing groups may still grow.
            continue

        reaction_tree = ReactionTreeFromExpansion(action).tree
        rxn = next(reaction_tree.reactions())  # type: ignore
        grouped[group_key].append(rxn)

    return list(map(tuple, grouped.values()))
def test_add_tree_mol():
    """Serializing a child molecule stores both the parent and the child entries."""
    serializer = MoleculeSerializer()
    parent_mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child_mol = TreeMolecule(smiles="CCO", parent=parent_mol)
    parent_id = id(parent_mol)

    child_id = serializer[child_mol]

    expected_store = {
        child_id: {
            "smiles": "CCO",
            "class": "TreeMolecule",
            "parent": parent_id,
            "transform": 2,
        },
        parent_id: {
            "smiles": "CCC",
            "class": "TreeMolecule",
            "parent": None,
            "transform": 1,
        },
    }
    assert child_id == id(child_mol)
    # The parent entry comes before the child entry in the store.
    assert list(serializer.store.keys()) == [parent_id, child_id]
    assert serializer.store == expected_store
def wrapper(parent, key_smiles, child_smiles_list, probs):
    """Register mocked reactions for ``key_smiles`` and return the child molecules created.

    :param parent: the molecule used as parent of every child and reaction
    :param key_smiles: the key under which the reactions are stored in ``actions``
    :param child_smiles_list: one collection of child SMILES per reaction;
        a falsy entry yields a reaction mocked with ``None`` as children
    :param probs: stored next to the reactions, unchanged
    :return: the lists of child molecules, one per non-empty entry
    """
    reactions = []
    created_children = []
    for smiles_group in child_smiles_list:
        if smiles_group:
            children = [
                TreeMolecule(parent=parent, smiles=smi) for smi in smiles_group
            ]
            created_children.append(children)
        else:
            children = None
        reactions.append(mocked_reaction(parent, children))

    actions[key_smiles] = reactions, probs
    return created_children
def test_retro_reaction_copy(simple_actions):
    """Copies keep molecule and reactant count; the index changes only when overridden."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)
    original = reactions[0]
    original.apply()

    plain_copy = original.copy()
    assert plain_copy.mol.smiles == original.mol.smiles
    assert len(plain_copy.reactants[0]) == len(original.reactants[0])
    assert plain_copy.index == original.index

    reindexed_copy = original.copy(index=2)
    assert reindexed_copy.mol.smiles == original.mol.smiles
    assert len(reindexed_copy.reactants[0]) == len(original.reactants[0])
    assert reindexed_copy.index != original.index
def test_serialize_deserialize_state(default_config):
    """Round-trip a single-molecule state through ``serialize`` and ``from_dict``."""
    mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    original_state = State([mol], default_config)
    serializer = MoleculeSerializer()

    serialized = original_state.serialize(serializer)

    serialized_mols = serialized["mols"]
    assert len(serialized_mols) == 1
    assert serialized_mols[0] == id(mol)

    deserializer = MoleculeDeserializer(serializer.store)
    restored_state = State.from_dict(serialized, default_config, deserializer)

    assert len(restored_state.mols) == 1
    assert restored_state.mols[0] == original_state.mols[0]
    assert restored_state.in_stock_list == original_state.in_stock_list
    assert restored_state.score == original_state.score
def test_retro_reaction_copy(get_action):
    """Copies of a templated reaction keep type, molecule and reactant count.

    The index is preserved by a plain ``copy()`` and differs when
    ``copy(index=2)`` is requested.
    """
    # The reaction comes from the ``get_action`` fixture; the molecule the
    # original test built here was never used and has been removed.
    reaction = get_action()
    # Read ``reactants`` once before copying, as the original test did.
    _ = reaction.reactants

    copy_ = reaction.copy()

    assert isinstance(copy_, TemplatedRetroReaction)
    assert copy_.mol.smiles == reaction.mol.smiles
    assert len(copy_.reactants[0]) == len(reaction.reactants[0])
    assert copy_.index == reaction.index

    copy_ = reaction.copy(index=2)

    assert copy_.mol.smiles == reaction.mol.smiles
    assert len(copy_.reactants[0]) == len(reaction.reactants[0])
    assert copy_.index != reaction.index
def test_expander_defaults(mock_expansion_policy):
    """With default arguments the expander returns two single-reaction groups."""
    smi = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    expander = AiZynthExpander()
    mock_expansion_policy(TreeMolecule(parent=None, smiles=smi))

    reactions = expander.do_expansion(smi)

    assert len(reactions) == 2
    first_group, second_group = reactions[0], reactions[1]
    assert len(first_group) == 1
    assert len(second_group) == 1

    first_rxn, second_rxn = first_group[0], second_group[0]
    assert first_rxn.mol.smiles == smi
    assert second_rxn.mol.smiles == smi
    assert len(first_rxn.reactants[0]) == 2
    assert len(second_rxn.reactants[0]) == 2

    # The two groups must describe different reactant sets.
    first_smiles = [mol.smiles for mol in first_rxn.reactants[0]]
    second_smiles = [mol.smiles for mol in second_rxn.reactants[0]]
    assert first_smiles != second_smiles
def test_filter_rejection(default_config, mock_keras_model):
    """The filter policy raises without a selection and rejects above the cutoff."""
    filter_policy = default_config.filter_policy
    filter_policy.load_from_config(files={"policy1": "dummy1"})
    target = TreeMolecule(
        parent=None, smiles="CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1"
    )
    reaction = SmilesBasedRetroReaction(
        target,
        reactants_str="CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O",
    )

    # No policy selected yet.
    with pytest.raises(PolicyException, match="selected"):
        filter_policy(reaction)

    filter_policy.select("policy1")

    filter_policy._config.filter_cutoff = 0.9
    with pytest.raises(RejectionException):
        filter_policy(reaction)

    # With the lower cutoff the call must complete without raising.
    filter_policy._config.filter_cutoff = 0.15
    filter_policy(reaction)
def test_reactants_count_rejection(default_config):
    """The reactants-count filter accepts one outcome of the template and rejects the other."""
    smarts = (
        "([C:3]-[N;H0;D2;+0:2]=[C;H0;D3;+0:1](-[c:4]1:[c:5]:[c:6]:[c:7]:[c:8]:[c:9]:1)-[c;H0;D3;+0:11](:[c:10]):[c:12])>>"
        "(O=[C;H0;D3;+0:1](-[NH;D2;+0:2]-[C:3])-[c:4]1:[c:5]:[c:6]:[c:7]:[c:8]:[c:9]:1.[c:10]:[cH;D2;+0:11]:[c:12])"
    )
    target = TreeMolecule(parent=None, smiles="c1c2c(ccc1)CCN=C2c3ccccc3")
    first = TemplatedRetroReaction(mol=target, smarts=smarts)
    count_filter = ReactantsCountFilter("dummy", default_config)

    assert len(first.reactants) == 2
    second = first.copy(index=1)

    # Outcome order is not assumed: whichever reaction's first outcome has a
    # single reactant is the one expected to pass the filter.
    if len(first.reactants[0]) == 1:
        accepted, rejected = first, second
    else:
        accepted, rejected = second, first

    assert count_filter(accepted) is None
    with pytest.raises(RejectionException):
        count_filter(rejected)
def get_action():
    """Return a factory that builds a ``TemplatedRetroReaction`` for a fixed molecule.

    The factory takes ``applicable`` (default True) to choose between two
    hard-coded templates.
    """
    target = TreeMolecule(smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1", parent=None)

    applicable_smarts = (
        "([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])"
        ">>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6])"
    )
    # NOTE(review): the product-side oxygen below reads ``H03`` where the
    # other template has ``H0:3`` - confirm whether this is intentional.
    other_smarts = (
        "([C:4]-[N;H0;D3;+0:5](-[C:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H03])>>"
        "(O-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([C:4]-[NH;D2;+0:5]-[C:6])"
    )

    def wrapper(applicable=True):
        template = applicable_smarts if applicable else other_smarts
        return TemplatedRetroReaction(target, smarts=template, metadata={"dummy": 1})

    return wrapper
def test_get_actions_two_policies(default_config, setup_template_expansion_policy):
    """Check ``get_actions`` with two selected policies, additive and non-additive."""

    def policy_names_of(action_list):
        return [action.metadata["policy_name"] for action in action_list]

    policy = default_config.expansion_policy
    for name in ("policy1", "policy2"):
        strategy, _ = setup_template_expansion_policy(name)
        policy.load(strategy)
    default_config.additive_expansion = True
    policy.select(["policy1", "policy2"])
    targets = [TreeMolecule(smiles="CCO", parent=None)]

    # Additive expansion: actions from both policies, in selection order.
    actions, priors = policy.get_actions(targets)
    assert policy_names_of(actions) == ["policy1"] * 2 + ["policy2"] * 2
    assert priors == [0.7, 0.2, 0.7, 0.2]

    policy._config.cutoff_cumulative = 1.0
    _, priors = policy.get_actions(targets)
    assert priors == [0.7, 0.2, 0.1, 0.7, 0.2, 0.1]

    policy._config.cutoff_number = 1
    _, priors = policy.get_actions(targets)
    assert priors == [0.7, 0.7]

    # Non-additive expansion: only the first policy's actions are expected.
    default_config.additive_expansion = False
    default_config.cutoff_number = 2
    actions, priors = policy.get_actions(targets)
    assert policy_names_of(actions) == ["policy1", "policy1"]
    assert priors == [0.7, 0.2]
def setup_graphviz_graph():
    """Return molecules, reactions, edges and frame colours for a one-molecule graph.

    :return: a 4-tuple of lists: ``[molecule]``, ``[reaction]``,
        ``[(molecule, reaction)]`` and ``["green"]``
    """
    molecule = TreeMolecule(smiles="CCCO", parent=None)
    reaction = RetroReaction(mol=molecule, smarts="")

    molecules = [molecule]
    reactions = [reaction]
    edges = [(molecule, reaction)]
    frame_colors = ["green"]
    return molecules, reactions, edges, frame_colors