示例#1
0
def test_set_reactants_single():
    """Assigning a single molecule should give one outcome with one reactant.

    ``reactants`` is a tuple of reactant tuples (compare the list-of-list
    test), so the expected value needs a trailing comma at BOTH nesting
    levels. ``((reactant1), )`` would only be ``(reactant1,)`` because
    parentheses alone do not create a tuple.
    """
    mol = TreeMolecule(parent=None, smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")
    reactant1 = TreeMolecule(parent=mol, smiles="N#Cc1cccc(N)c1F")
    rxn = FixedRetroReaction(mol)

    rxn.reactants = reactant1

    assert rxn.reactants == ((reactant1,),)
示例#2
0
def test_set_reactants_list_of_list():
    """An already-nested tuple of reactants is stored unchanged."""
    product = TreeMolecule(parent=None,
                           smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")
    amine = TreeMolecule(parent=product, smiles="N#Cc1cccc(N)c1F")
    acyl_chloride = TreeMolecule(parent=product, smiles="O=C(Cl)c1ccc(F)cc1")
    reaction = FixedRetroReaction(product)

    nested = ((amine, acyl_chloride),)
    reaction.reactants = nested

    assert reaction.reactants == ((amine, acyl_chloride),)
def test_chaining():
    """Serializing a child also serializes its parent, and both round-trip.

    The deserialized molecules must be new objects, not the originals.
    """
    serializer = MoleculeSerializer()
    parent_mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child_mol = TreeMolecule(smiles="CCO", parent=parent_mol)

    child_id = serializer[child_mol]

    deserializer = MoleculeDeserializer(serializer.store)

    expectations = ((child_id, child_mol), (id(parent_mol), parent_mol))
    for key, original in expectations:
        assert deserializer[key].smiles == original.smiles
    for key, _original in expectations:
        assert id(deserializer[key]) != key
示例#4
0
def test_get_actions(default_config, setup_template_expansion_policy):
    """Exercise selection and the cumulative/number cut-offs of one policy."""
    template_strategy, _ = setup_template_expansion_policy()
    policy = default_config.expansion_policy
    policy.load(template_strategy)

    molecules = [TreeMolecule(smiles="CCO", parent=None)]

    # No policy selected yet: requesting actions must fail.
    with pytest.raises(PolicyException, match="selected"):
        policy.get_actions(molecules)

    policy.select("policy1")
    actions, priors = policy.get_actions(molecules)

    assert priors == [0.7, 0.2]
    assert [act.metadata["policy_name"] for act in actions] == [
        "policy1",
        "policy1",
    ]

    # A cumulative cut-off of 1.0 lets the third action through ...
    policy._config.cutoff_cumulative = 1.0
    _, priors = policy.get_actions(molecules)
    assert priors == [0.7, 0.2, 0.1]

    # ... while a hard cut-off on the count keeps only the top one.
    policy._config.cutoff_number = 1
    _, priors = policy.get_actions(molecules)
    assert priors == [0.7]
 def wrapper(root_smiles, config):
     """Make the mocked ``create_root`` return a node rooted at ``root_smiles``.

     :return: the root ``TreeMolecule`` placed in the node's state
     """
     root_mol = TreeMolecule(parent=None, transform=0, smiles=root_smiles)
     root_node = Node(
         state=State(mols=[root_mol], config=config),
         owner=None,
         config=config,
     )
     mocked_create_root.return_value = root_node
     return root_mol
示例#6
0
def test_retro_reaction_fingerprint(simple_actions):
    """The first action's fingerprint (args 2, 10) has a fixed value."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _priors = simple_actions(target)

    fingerprint = reactions[0].fingerprint(2, 10)
    expected = [0, -1, 0, -1, -1, 0, -1, -1, 0, 0]

    assert list(fingerprint) == expected
示例#7
0
def test_save_molecule_images():
    """Images are written once per distinct (molecule, colour) combination."""

    def count_images():
        return len(os.listdir(image.IMAGE_FOLDER))

    files_before = count_images()

    molecules = [
        TreeMolecule(smiles=smi, parent=None)
        for smi in ("CCCO", "CCCO", "CCCCO")
    ]

    # The duplicated "CCCO" in the same colour adds only one file.
    image.save_molecule_images(molecules, ["green"] * 3)
    assert count_images() == files_before + 2

    # A new colour for the second "CCCO" adds exactly one more file.
    image.save_molecule_images(molecules, ["green", "orange", "green"])
    assert count_images() == files_before + 3
示例#8
0
 def _score_reaction_tree(self, tree: ReactionTree) -> float:
     """Score ``tree`` via an ``MctsState`` built from its leaf molecules.

     Each leaf's ``transform`` is half its depth in the tree. This is
     presumably because depth counts both molecule and reaction nodes —
     TODO confirm against ``ReactionTree.depth``.
     """
     leaf_molecules = []
     for leaf in tree.leafs():
         leaf_molecules.append(
             TreeMolecule(parent=None,
                          transform=tree.depth(leaf) // 2,
                          smiles=leaf.smiles))
     return MctsState(leaf_molecules, self._config).score
示例#9
0
def test_smiles_based_retroreaction():
    """A reactant SMILES string gives a single outcome, in the given order."""
    amide = TreeMolecule(smiles="CNC(C)=O", parent=None)
    reaction = SmilesBasedRetroReaction(amide, reactants_str="CC(=O)O.CN")

    assert len(reaction.reactants) == 1
    outcome = reaction.reactants[0]
    assert outcome[0].smiles == "CC(=O)O"
    assert outcome[1].smiles == "CN"
    assert reaction.smiles == "CNC(C)=O>>CC(=O)O.CN"
示例#10
0
def test_reaction_failure_rdchiral(simple_actions, mocker):
    """If rdchiral raises, applying the reaction yields no products."""
    rdchiral_run = mocker.patch("aizynthfinder.chem.rdc.rdchiralRun")
    rdchiral_run.side_effect = RuntimeError("Oh no!")
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)

    # The RuntimeError must not propagate; the result is simply falsy.
    assert not reactions[0].apply()
示例#11
0
def setup_graphviz_graph():
    """Build a minimal graph: one green molecule linked to one empty reaction."""
    molecule = TreeMolecule(smiles="CCCO", parent=None)
    empty_reaction = RetroReaction(mol=molecule, smarts="")

    graph = image.GraphvizReactionGraph()
    graph.add_molecule(molecule, "green")
    graph.add_reaction(empty_reaction)
    graph.add_edge(molecule, empty_reaction)

    return graph
示例#12
0
def test_retro_reaction(simple_actions):
    """The first action yields two reactants; the third yields nothing."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)

    first_outcome = reactions[0].apply()[0]
    assert first_outcome[0].smiles == "CCCCOc1ccc(CC(=O)Cl)cc1"
    assert first_outcome[1].smiles == "CNO"

    assert reactions[2].apply() == ()
示例#13
0
def test_expander_top1(mock_expansion_policy):
    """With ``return_n=1`` only the top reactant group is returned."""
    expander = AiZynthExpander()
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))

    groups = expander.do_expansion(target_smiles, return_n=1)

    assert len(groups) == 1
    top_reaction = groups[0][0]
    reactant_smiles = [m.smiles for m in top_reaction.reactants[0]]
    assert reactant_smiles == ["CCCCOc1ccc(CC(=O)Cl)cc1", "CNO"]
示例#14
0
def test_feasible(filter_policy, mock_policy_model, mocker, simple_actions):
    """Feasibility of the same reaction flips with the filter cut-off."""
    filter_policy.load("dummy.hdf5", "policy1")
    filter_policy.select("policy1")
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)
    candidate = reactions[0]

    for cutoff, expected in ((0.9, False), (0.15, True)):
        filter_policy._config.filter_cutoff = cutoff
        assert bool(filter_policy.is_feasible(candidate)) is expected
示例#15
0
def test_expander_filter_policy(mock_expansion_policy, mock_policy_model):
    """With a selected filter policy, feasibility lands in reaction metadata."""
    expander = AiZynthExpander()
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))
    expander.filter_policy.load("dummy.hdf5", "policy1")
    expander.filter_policy.select("policy1")

    groups = expander.do_expansion(target_smiles)

    assert len(groups) == 2
    for group in groups:
        assert group[0].metadata["feasibility"] == 0.2
示例#16
0
def test_create_fixed_reaction():
    """A fixed reaction keeps its SMILES but has no usable RDKit reaction.

    Accessing ``rd_reaction`` raises ValueError, and ``apply`` raises
    NotImplementedError.
    """
    reaction_smiles = "[C:1](=[O:2])([cH3:3])[N:4][cH3:5]>>Cl[C:1](=[O:2])[cH3:3].[N:4][cH3:5]"
    target = TreeMolecule(parent=None, smiles="N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F")

    fixed = FixedRetroReaction(target, smiles=reaction_smiles)

    assert fixed.smiles == reaction_smiles

    with pytest.raises(ValueError):
        _ = fixed.rd_reaction

    with pytest.raises(NotImplementedError):
        fixed.apply()
示例#17
0
    def create_root(cls, smiles: str, tree: MctsSearchTree,
                    config: Configuration) -> "MctsNode":
        """
        Build the root node of a search tree from a SMILES string.

        :param smiles: SMILES of the single molecule in the root state
        :param tree: the search tree that will own the node
        :param config: settings of the tree search algorithm
        :return: the new root node
        """
        root_state = MctsState(
            mols=[TreeMolecule(parent=None, transform=0, smiles=smiles)],
            config=config,
        )
        return MctsNode(state=root_state, owner=tree, config=config)
示例#18
0
def test_expander_filter(mock_expansion_policy):
    """A custom filter function removes reactions that produce "CNO"."""

    def drop_cno(reaction):
        return all(mol.smiles != "CNO" for mol in reaction.reactants[0])

    expander = AiZynthExpander()
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))

    groups = expander.do_expansion(target_smiles, filter_func=drop_cno)

    assert len(groups) == 1
    kept = [mol.smiles for mol in groups[0][0].reactants[0]]
    assert kept == ["CCCCBr", "CN(O)C(=O)Cc1ccc(O)cc1"]
示例#19
0
    def do_expansion(
        self,
        smiles: str,
        return_n: int = 5,
        filter_func: Callable[[RetroReaction], bool] = None,
    ) -> List[Tuple[FixedRetroReaction, ...]]:
        """
        Do the expansion of the given molecule returning a list of
        reaction tuples. Each tuple in the list contains reactions
        producing the same reactants. Hence, the nested structure of the
        return value is a way to group reactions.

        If a filter policy is set up, the feasibility probability of each
        returned reaction is added to its metadata under "feasibility".
        Each returned reaction also gets an "expansion_rank" metadata entry:
        the 1-based position of its group in the returned list. All
        reactions in the same tuple therefore share the same rank.

        The additional filter function makes it possible to do customized
        filtering. The callable should take as only argument a `RetroReaction`
        object and return True if the reaction can be kept or False if it should
        be removed.

        :param smiles: the SMILES string of the target molecule
        :param return_n: the maximum number of groups (tuples) to return
        :param filter_func: an additional filter function
        :return: the grouped reactions
        """
        mol = TreeMolecule(parent=None, smiles=smiles)
        actions, _ = self.expansion_policy.get_actions([mol])
        results: Dict[Tuple[str, ...],
                      List[FixedRetroReaction]] = defaultdict(list)
        # 1-based position of each reactant group in `results`. It is kept
        # next to the dict so an existing group's rank is an O(1) lookup.
        group_ranks: Dict[Tuple[str, ...], int] = {}
        for action in actions:
            reactants = action.reactants
            if not reactants:
                continue
            if filter_func and not filter_func(action):
                continue
            # Group by the order-independent set of reactant InChI keys.
            unique_key = tuple(
                sorted(reactant.inchi_key for reactant in reactants[0]))
            if unique_key not in group_ranks:
                if len(group_ranks) >= return_n:
                    # New reactant set, but `return_n` groups are already
                    # collected. Skip it before paying for the feasibility
                    # evaluation; later actions may still join existing groups.
                    continue
                group_ranks[unique_key] = len(group_ranks) + 1
            for name in self.filter_policy.selection or []:
                # Only the first selected filter exposing `feasibility` is used.
                if hasattr(self.filter_policy[name], "feasibility"):
                    _, feasibility_prob = self.filter_policy[name].feasibility(
                        action)
                    action.metadata["feasibility"] = float(feasibility_prob)
                    break
            # The rank is the group's position, so duplicates of an earlier
            # reactant set no longer collide with the rank of the next group.
            action.metadata["expansion_rank"] = group_ranks[unique_key]
            rxn = next(ReactionTreeFromExpansion(
                action).tree.reactions())  # type: ignore
            results[unique_key].append(rxn)
        return [tuple(reactions) for reactions in results.values()]
def test_add_tree_mol():
    """Serializing a child stores the parent first, then the child.

    Records are keyed by ``id()``. The child is expected to have
    ``transform`` 2 here, one more than its parent's 1.
    """
    serializer = MoleculeSerializer()
    parent_mol = TreeMolecule(parent=None, smiles="CCC", transform=1)
    child_mol = TreeMolecule(smiles="CCO", parent=parent_mol)

    child_id = serializer[child_mol]
    parent_id = id(parent_mol)

    assert child_id == id(child_mol)
    assert list(serializer.store.keys()) == [parent_id, child_id]

    child_record = {
        "smiles": "CCO",
        "class": "TreeMolecule",
        "parent": parent_id,
        "transform": 2,
    }
    parent_record = {
        "smiles": "CCC",
        "class": "TreeMolecule",
        "parent": None,
        "transform": 1,
    }
    assert serializer.store == {child_id: child_record, parent_id: parent_record}
示例#21
0
 def wrapper(parent, key_smiles, child_smiles_list, probs):
     """Register mocked expansion outcomes for ``key_smiles`` in ``actions``.

     Each entry of ``child_smiles_list`` is one outcome, given as a sequence
     of SMILES. A falsy entry produces a mocked reaction with ``None``
     reactants.

     :return: the molecule lists created for the non-empty entries
     """
     reactions = []
     molecule_groups = []
     for child_smiles in child_smiles_list:
         if child_smiles:
             group = [
                 TreeMolecule(parent=parent, smiles=smi) for smi in child_smiles
             ]
             molecule_groups.append(group)
             reactions.append(mocked_reaction(parent, group))
         else:
             reactions.append(mocked_reaction(parent, None))
     actions[key_smiles] = reactions, probs
     return molecule_groups
示例#22
0
def test_retro_reaction_copy(simple_actions):
    """A copy keeps target and reactants; ``index`` changes only if given."""
    target = TreeMolecule(parent=None, smiles="CCCCOc1ccc(CC(=O)N(C)O)cc1")
    reactions, _ = simple_actions(target)
    source = reactions[0]
    source.apply()

    same_index_copy = source.copy()

    assert same_index_copy.mol.smiles == source.mol.smiles
    assert len(same_index_copy.reactants[0]) == len(source.reactants[0])
    assert same_index_copy.index == source.index

    reindexed_copy = source.copy(index=2)

    assert reindexed_copy.mol.smiles == source.mol.smiles
    assert len(reindexed_copy.reactants[0]) == len(source.reactants[0])
    assert reindexed_copy.index != source.index
def test_serialize_deserialize_state(default_config):
    """A State round-trips through serialize / from_dict unchanged.

    Molecules are serialized as their ``id()`` references.
    """
    molecule = TreeMolecule(parent=None, smiles="CCC", transform=1)
    original_state = State([molecule], default_config)
    serializer = MoleculeSerializer()

    serialized = original_state.serialize(serializer)

    assert len(serialized["mols"]) == 1
    assert serialized["mols"][0] == id(molecule)

    restored_state = State.from_dict(
        serialized, default_config, MoleculeDeserializer(serializer.store))

    assert len(restored_state.mols) == 1
    assert restored_state.mols[0] == original_state.mols[0]
    assert restored_state.in_stock_list == original_state.in_stock_list
    assert restored_state.score == original_state.score
示例#24
0
def test_retro_reaction_copy(get_action):
    """A copy of a templated reaction keeps its type, target and reactants.

    ``index`` is preserved unless a new one is passed to ``copy``.
    """
    # The previously constructed, never-used TreeMolecule has been removed:
    # the fixture builds its own target molecule.
    reaction = get_action()
    _ = reaction.reactants  # touch reactants before copying (presumably forces evaluation)

    copy_ = reaction.copy()

    assert isinstance(copy_, TemplatedRetroReaction)
    assert copy_.mol.smiles == reaction.mol.smiles
    assert len(copy_.reactants[0]) == len(reaction.reactants[0])
    assert copy_.index == reaction.index

    copy_ = reaction.copy(index=2)

    assert copy_.mol.smiles == reaction.mol.smiles
    assert len(copy_.reactants[0]) == len(reaction.reactants[0])
    assert copy_.index != reaction.index
示例#25
0
def test_expander_defaults(mock_expansion_policy):
    """Default expansion yields two single-reaction groups with distinct reactants."""
    expander = AiZynthExpander()
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    mock_expansion_policy(TreeMolecule(parent=None, smiles=target_smiles))

    groups = expander.do_expansion(target_smiles)

    assert len(groups) == 2
    for group in groups:
        assert len(group) == 1
        assert group[0].mol.smiles == target_smiles
        assert len(group[0].reactants[0]) == 2

    first_smiles, second_smiles = (
        [mol.smiles for mol in group[0].reactants[0]] for group in groups
    )
    assert first_smiles != second_smiles
示例#26
0
def test_filter_rejection(default_config, mock_keras_model):
    """The filter policy rejects or accepts a reaction depending on the cut-off.

    Calling it before any policy is selected raises PolicyException.
    """
    policy = default_config.filter_policy
    policy.load_from_config(files={"policy1": "dummy1"})

    target = TreeMolecule(
        parent=None, smiles="CN1CCC(C(=O)c2cccc(NC(=O)c3ccc(F)cc3)c2F)CC1")
    reactants = "CN1CCC(Cl)CC1.N#Cc1cccc(NC(=O)c2ccc(F)cc2)c1F.O"
    reaction = SmilesBasedRetroReaction(target, reactants_str=reactants)

    with pytest.raises(PolicyException, match="selected"):
        policy(reaction)

    policy.select("policy1")

    policy._config.filter_cutoff = 0.9
    with pytest.raises(RejectionException):
        policy(reaction)

    # A low enough cut-off lets the reaction through without an exception.
    policy._config.filter_cutoff = 0.15
    policy(reaction)
示例#27
0
def test_reactants_count_rejection(default_config):
    """ReactantsCountFilter accepts one outcome of the template and rejects the other."""
    smarts = (
        "([C:3]-[N;H0;D2;+0:2]=[C;H0;D3;+0:1](-[c:4]1:[c:5]:[c:6]:[c:7]:[c:8]:[c:9]:1)-[c;H0;D3;+0:11](:[c:10]):[c:12])>>"
        "(O=[C;H0;D3;+0:1](-[NH;D2;+0:2]-[C:3])-[c:4]1:[c:5]:[c:6]:[c:7]:[c:8]:[c:9]:1.[c:10]:[cH;D2;+0:11]:[c:12])"
    )
    mol = TreeMolecule(parent=None, smiles="c1c2c(ccc1)CCN=C2c3ccccc3")
    rxn1 = TemplatedRetroReaction(mol=mol, smarts=smarts)
    # Named `count_filter` rather than `filter` to avoid shadowing the builtin.
    count_filter = ReactantsCountFilter("dummy", default_config)

    assert len(rxn1.reactants) == 2

    rxn2 = rxn1.copy(index=1)
    # Swap if needed so that rxn1 is not the reaction whose first reactant
    # set has exactly one molecule. rxn2 must then pass and rxn1 be rejected.
    if len(rxn1.reactants[0]) == 1:
        rxn1, rxn2 = rxn2, rxn1

    assert count_filter(rxn2) is None

    with pytest.raises(RejectionException):
        count_filter(rxn1)
示例#28
0
def get_action():
    """Return a factory for templated reactions on a fixed target molecule.

    Each call returns a new ``TemplatedRetroReaction`` with
    ``metadata={"dummy": 1}``. The template is chosen by ``applicable``.
    """
    target_smiles = "CCCCOc1ccc(CC(=O)N(C)O)cc1"
    target = TreeMolecule(smiles=target_smiles, parent=None)

    applicable_smarts = (
        "([#8:4]-[N;H0;D3;+0:5](-[C;D1;H3:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3])"
        ">>(Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([#8:4]-[NH;D2;+0:5]-[C;D1;H3:6])"
    )
    # NOTE(review): the first atom list below has `[O;D1;H03]`, which lacks
    # the `:3` atom map used elsewhere (`[O;D1;H0:3]`). It may be a typo.
    # Confirm whether the template's non-applicability relies on it before
    # changing it.
    inapplicable_smarts = (
        "([C:4]-[N;H0;D3;+0:5](-[C:6])-[C;H0;D3;+0:1](-[C:2])=[O;D1;H03])>>"
        "(O-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3]).([C:4]-[NH;D2;+0:5]-[C:6])"
    )

    def wrapper(applicable=True):
        smarts = applicable_smarts if applicable else inapplicable_smarts
        return TemplatedRetroReaction(target,
                                      smarts=smarts,
                                      metadata={"dummy": 1})

    return wrapper
示例#29
0
def test_get_actions_two_policies(default_config,
                                  setup_template_expansion_policy):
    """Two policies: additive expansion concatenates, otherwise only the first is used."""
    policy = default_config.expansion_policy
    for policy_name in ("policy1", "policy2"):
        strategy, _ = setup_template_expansion_policy(policy_name)
        policy.load(strategy)
    default_config.additive_expansion = True

    policy.select(["policy1", "policy2"])
    molecules = [TreeMolecule(smiles="CCO", parent=None)]

    actions, priors = policy.get_actions(molecules)

    names = [act.metadata["policy_name"] for act in actions]
    assert names == ["policy1"] * 2 + ["policy2"] * 2
    assert priors == [0.7, 0.2] * 2

    policy._config.cutoff_cumulative = 1.0
    _, priors = policy.get_actions(molecules)
    assert priors == [0.7, 0.2, 0.1] * 2

    policy._config.cutoff_number = 1
    _, priors = policy.get_actions(molecules)
    assert priors == [0.7] * 2

    # Presumably `policy._config` is `default_config` itself, so these
    # settings reach the policy — TODO confirm.
    default_config.additive_expansion = False
    default_config.cutoff_number = 2

    actions, priors = policy.get_actions(molecules)

    names = [act.metadata["policy_name"] for act in actions]
    assert names == ["policy1", "policy1"]
    assert priors == [0.7, 0.2]
示例#30
0
def setup_graphviz_graph():
    """Return (molecules, reactions, edges, colours) for a one-edge graph."""
    molecule = TreeMolecule(smiles="CCCO", parent=None)
    empty_reaction = RetroReaction(mol=molecule, smarts="")
    edges = [(molecule, empty_reaction)]
    return [molecule], [empty_reaction], edges, ["green"]