예제 #1
0
def test_init():
    """Check that every parsed atom of a SMARTS string becomes a graph node.

    For each pattern in ``TEST_BANK`` a ``SMARTSGraph`` is built; the
    position of every ``'atom'`` entry found in its AST must be present
    among the graph's nodes.
    """
    for pattern in TEST_BANK:
        smarts_graph = SMARTSGraph(pattern)
        parsed_atoms = smarts_graph.ast.find_data('atom')
        # Only the position of each atom matters, not the atom itself.
        for position, _ in enumerate(parsed_atoms):
            assert position in smarts_graph.nodes()
예제 #2
0
파일: test_graph.py 프로젝트: iModels/foyer
def test_init():
    """Build a ``SMARTSGraph`` per test pattern and verify its node indices.

    The n-th ``'atom'`` entry found in the graph's AST (counting from zero)
    is expected to appear as node ``n`` of the graph.
    """
    for smarts in TEST_BANK:
        graph = SMARTSGraph(smarts)
        node_index = 0
        for _ in graph.ast.find_data('atom'):
            assert node_index in graph.nodes()
            node_index += 1
예제 #3
0
 def _rule_match_count(top, typemap, smart, count):
     """Assert that a SMARTS pattern produces exactly ``count`` matches.

     Parameters
     ----------
     top : object accepted by ``SMARTSGraph.find_matches``
         The structure to search.
     typemap : dict
         Passed both to the ``SMARTSGraph`` constructor and to
         ``find_matches``.
     smart : str
         SMARTS string used to build the rule.
     count : int
         Expected number of matches.

     Raises
     ------
     AssertionError
         If the number of matches differs from ``count``.
     """
     rule = SMARTSGraph(
         name="test",
         parser=smarts_parser,
         smarts_string=smart,
         typemap=typemap,
     )
     # Compare by value: ``is`` tests object identity, which only happens to
     # hold for the small integers CPython caches.
     assert len(list(rule.find_matches(top, typemap))) == count
예제 #4
0
 def _rule_match(top, typemap, smart, result):
     """Assert whether a SMARTS pattern matches anything in ``top``.

     A ``SMARTSGraph`` is built from ``smart`` and run against ``top``;
     ``result`` must be the boolean ``True`` when at least one match is
     expected and ``False`` when none is.
     """
     graph_options = dict(
         name="test",
         parser=smarts_parser,
         smarts_string=smart,
         typemap=typemap,
     )
     rule = SMARTSGraph(**graph_options)
     found = list(rule.find_matches(top, typemap))
     assert bool(found) is result
예제 #5
0
def test_fused_ring():
    """Check which atoms of ``fused.mol2`` match the fused-ring pattern.

    The pattern is expected to match exactly two atoms, with indices 3
    and 4.
    """
    structure = pmd.load_file(get_fn('fused.mol2'), structure=True)
    top, _ = generate_topology(structure)
    fused_pattern = '[#6]12[#6][#6][#6][#6][#6]1[#6][#6][#6][#6]2'
    rule = SMARTSGraph(name='test', parser=PARSER,
                       smarts_string=fused_pattern)

    match_indices = list(rule.find_matches(top))
    for expected_idx in (3, 4):
        assert expected_idx in match_indices
    assert len(match_indices) == 2
예제 #6
0
def test_ring_count():
    """Check the ``R<n>`` ring-count primitive on two test structures.

    In ``fused.mol2`` the pattern ``[#6;R2]`` must match exactly atoms 3 and
    4 and ``[#6;R1]`` exactly the other eight atoms; in ``ring.mol2`` the
    pattern ``[#6;R1]`` must match exactly atoms 0 through 5.
    """
    def matched(topology, smarts):
        # Build a fresh rule for the pattern and collect its matches.
        graph = SMARTSGraph(name='test', parser=PARSER, smarts_string=smarts)
        return list(graph.find_matches(topology))

    # Two rings
    fused = pmd.load_file(get_fn('fused.mol2'), structure=True)
    fused_top, _ = generate_topology(fused)

    in_two_rings = matched(fused_top, '[#6;R2]')
    assert 3 in in_two_rings
    assert 4 in in_two_rings
    assert len(in_two_rings) == 2

    in_one_ring = matched(fused_top, '[#6;R1]')
    for atom_idx in (0, 1, 2, 5, 6, 7, 8, 9):
        assert atom_idx in in_one_ring
    assert len(in_one_ring) == 8

    # One ring
    ring = pmd.load_file(get_fn('ring.mol2'), structure=True)
    ring_top, _ = generate_topology(ring)

    ring_atoms = matched(ring_top, '[#6;R1]')
    for atom_idx in range(6):
        assert atom_idx in ring_atoms
    assert len(ring_atoms) == 6
예제 #7
0
def test_fused_ring():
    """Match the fused-ring SMARTS pattern against ``fused.mol2``.

    Exactly two matches are expected: atom indices 3 and 4.
    """
    top, _ = generate_topology(
        pmd.load_file(get_fn('fused.mol2'), structure=True))
    rule_options = {
        'name': 'test',
        'parser': PARSER,
        'smarts_string': '[#6]12[#6][#6][#6][#6][#6]1[#6][#6][#6][#6]2',
    }
    rule = SMARTSGraph(**rule_options)

    found = list(rule.find_matches(top))
    assert 3 in found
    assert 4 in found
    assert len(found) == 2
예제 #8
0
파일: test_graph.py 프로젝트: iModels/foyer
def test_lazy_cycle_finding():
    """Check that a ``cycles`` attribute only appears for ring patterns.

    After matching ``[C]`` against ``ethane.mol2`` no atom of the topology
    may have a ``cycles`` attribute; after matching a pattern containing
    ``R1`` or ``r6`` every atom must have one.
    """
    structure = pmd.load_file(get_fn('ethane.mol2'), structure=True)
    top, _ = generate_topology(structure)

    # Pattern without a ring token.
    plain_rule = SMARTSGraph(smarts_string='[C]')
    list(plain_rule.find_matches(top))
    cycle_flags = [hasattr(atom, 'cycles') for atom in top.atoms()]
    assert not any(cycle_flags)

    # Patterns with a ring token.
    for ring_token in ('R1', 'r6'):
        ring_rule = SMARTSGraph(smarts_string='[C;{}]'.format(ring_token))
        list(ring_rule.find_matches(top))
        cycle_flags = [hasattr(atom, 'cycles') for atom in top.atoms()]
        assert all(cycle_flags)
예제 #9
0
def _update_defs(atomtypes, nonbonded, forcefield):
    """Strip references to undefined atom types from SMARTS definitions.

    The ``def`` attribute of every child of ``atomtypes`` is turned into a
    ``SMARTSGraph`` and the atom-type labels it references are collected.
    Labels that are not the ``name`` of any child are removed (as ``%label``)
    from each definition containing them; a warning is issued per removal
    and the edited string is written back to the child's ``def`` attribute.

    Parameters
    ----------
    atomtypes : element
        Parent whose children carry ``name`` and ``def`` attributes; it is
        modified in place.  NOTE(review): presumably an lxml element (needs
        ``iterchildren``, integer indexing and ``set``) -- confirm.
    nonbonded : object
        Unused; kept for interface compatibility.
    forcefield : object
        Must expose a ``parser`` attribute, which is handed to
        ``SMARTSGraph``.
    """
    def_list = [i.get('def') for i in atomtypes.iterchildren()]
    name_list = [i.get('name') for i in atomtypes.iterchildren()]
    smarts_parser = forcefield.parser

    referenced_types = set()
    for smarts_string, name in zip(def_list, name_list):
        smarts_graph = SMARTSGraph(smarts_string,
                                   parser=smarts_parser,
                                   name=name)
        for atom_expr in nx.get_node_attributes(smarts_graph,
                                                name='atom').values():
            for label in atom_expr.find_data('has_label'):
                # Drop the first character of the label token; it is added
                # back as '%' when the label is searched for below.
                referenced_types.add(label.children[0][1:])
    # Sorted so that warnings are emitted in a reproducible order.
    extra_types = sorted(referenced_types - set(name_list))

    for extra in extra_types:
        extra_edit = '%' + extra
        for i, definition in enumerate(def_list):
            # Test for the full label rather than the bare type name, so a
            # name that merely occurs as a substring is not reported and
            # ``find`` below can never return -1.
            # NOTE(review): '%a' is still found inside '%ab'; a token
            # boundary check would be needed to rule that out.
            if not definition or extra_edit not in definition:
                continue
            warnings.warn('Removing undefined atom type `{}`'
                          ' from SMARTS string `{}`'.format(
                              extra, definition))
            extra_index = definition.find(extra_edit)
            if extra_index > 0 and definition[extra_index - 1] == ';':
                # Label is first in its list: remove it with the comma
                # that follows.
                new_def = definition.replace(extra_edit + ',', '')
            else:
                # Otherwise remove it with the comma that precedes it.
                new_def = definition.replace(',' + extra_edit, '')
            atomtypes[i].set('def', new_def)
            # Keep the working copy current: removing a second undefined
            # type from the same definition must build on this edit
            # instead of starting from the original string again.
            def_list[i] = new_def
예제 #10
0
def _update_defs(atomtypes, nonbonded, forcefield):
    """Remove labels of undefined atom types from SMARTS definitions.

    Each child of ``atomtypes`` contributes a ``name`` and a ``def``
    attribute.  Every definition is parsed into a ``SMARTSGraph`` to collect
    the atom-type labels it references; labels that match no child's
    ``name`` are deleted (as ``%label``) from the definitions containing
    them.  Each deletion triggers a warning and the new string is stored in
    the child's ``def`` attribute.

    Parameters
    ----------
    atomtypes : element
        Parent of the atom-type entries; modified in place.
        NOTE(review): presumably an lxml element (needs ``iterchildren``,
        integer indexing and ``set``) -- confirm.
    nonbonded : object
        Unused; kept for interface compatibility.
    forcefield : object
        Must expose a ``parser`` attribute, which is handed to
        ``SMARTSGraph``.
    """
    def_list = [i.get("def") for i in atomtypes.iterchildren()]
    name_list = [i.get("name") for i in atomtypes.iterchildren()]
    smarts_parser = forcefield.parser

    referenced_types = set()
    for smarts_string, name in zip(def_list, name_list):
        smarts_graph = SMARTSGraph(smarts_string,
                                   parser=smarts_parser,
                                   name=name)
        for atom_expr in nx.get_node_attributes(smarts_graph,
                                                name="atom").values():
            for label in atom_expr.find_data("has_label"):
                # Drop the first character of the label token; it is added
                # back as "%" when the label is searched for below.
                referenced_types.add(label.children[0][1:])
    # Sorted so that warnings are emitted in a reproducible order.
    extra_types = sorted(referenced_types - set(name_list))

    for extra in extra_types:
        extra_edit = "%" + extra
        for i, definition in enumerate(def_list):
            # Look for the full label, not the bare type name: a name that
            # only appears as a substring must not be reported, and this
            # guarantees ``find`` below succeeds.
            # NOTE(review): "%a" is still found inside "%ab"; a token
            # boundary check would be needed to rule that out.
            if not definition or extra_edit not in definition:
                continue
            warnings.warn("Removing undefined atom type `{}`"
                          " from SMARTS string `{}`".format(
                              extra, definition))
            extra_index = definition.find(extra_edit)
            if extra_index > 0 and definition[extra_index - 1] == ";":
                # Label is first in its list: remove it with the comma
                # that follows.
                new_def = definition.replace(extra_edit + ",", "")
            else:
                # Otherwise remove it with the comma that precedes it.
                new_def = definition.replace("," + extra_edit, "")
            atomtypes[i].set("def", new_def)
            # Remember the edit so that a later removal from the same
            # definition does not overwrite it with a string derived from
            # the original definition.
            def_list[i] = new_def
예제 #11
0
def test_ring_count():
    """Check the ``R<n>`` ring-count primitive when a typemap is supplied.

    In ``fused.mol2`` the pattern ``[#6;R2]`` must match exactly atoms 3 and
    4 and ``[#6;R1]`` exactly the other eight atoms; in ``ring.mol2`` the
    pattern ``[#6;R1]`` must match exactly atoms 0 through 5.
    """
    def blank_typemap(topology):
        # One empty entry per atom, keyed by the atom's index.
        return {
            atom.index: {
                'whitelist': set(),
                'blacklist': set(),
                'atomtype': None
            }
            for atom in topology.atoms()
        }

    def matched(topology, atom_typemap, smarts):
        graph = SMARTSGraph(name='test',
                            parser=PARSER,
                            smarts_string=smarts,
                            typemap=atom_typemap)
        return list(graph.find_matches(topology, atom_typemap))

    # Two rings
    fused = pmd.load_file(get_fn('fused.mol2'), structure=True)
    fused_top, _ = generate_topology(fused)
    fused_typemap = blank_typemap(fused_top)

    in_two_rings = matched(fused_top, fused_typemap, '[#6;R2]')
    assert 3 in in_two_rings
    assert 4 in in_two_rings
    assert len(in_two_rings) == 2

    in_one_ring = matched(fused_top, fused_typemap, '[#6;R1]')
    for atom_idx in (0, 1, 2, 5, 6, 7, 8, 9):
        assert atom_idx in in_one_ring
    assert len(in_one_ring) == 8

    # One ring
    ring = pmd.load_file(get_fn('ring.mol2'), structure=True)
    ring_top, _ = generate_topology(ring)
    ring_typemap = blank_typemap(ring_top)

    ring_atoms = matched(ring_top, ring_typemap, '[#6;R1]')
    for atom_idx in range(6):
        assert atom_idx in ring_atoms
    assert len(ring_atoms) == 6
예제 #12
0
    def validate_smarts(self, debug):
        """Check SMARTS definitions for missing or non-parseable strings.

        Every entry of ``self.atom_types`` is inspected:

        * an entry without a ``def`` attribute is recorded as missing;
        * an entry with an empty ``def`` triggers a warning and is skipped;
        * a ``def`` that ``self.smarts_parser`` cannot parse is recorded as
          a ``ValidationError``;
        * a label referencing an atom type that is not in
          ``self.atom_type_names`` is recorded as a ``ValidationError``.

        The collected errors are handed to ``raise_collected``.  Afterwards
        a ``ValidationWarning`` reports the missing definitions.

        Parameters
        ----------
        debug : bool
            If true, the warning lists the names of the atom types without
            a definition; otherwise it only reports how many there are.
        """
        missing_smarts = []
        errors = []
        for entry in self.atom_types:
            smarts_string = entry.attrib.get('def')
            # The absent-attribute test must come before the emptiness test:
            # ``None`` is falsy, so with the checks the other way round this
            # branch was unreachable and ``missing_smarts`` stayed empty.
            if smarts_string is None:
                missing_smarts.append(entry.attrib['name'])
                continue
            if not smarts_string:
                warn("You have empty smart definition(s)", ValidationWarning)
                continue
            name = entry.attrib['name']
            # make sure smarts string can be parsed
            try:
                self.smarts_parser.parse(smarts_string)
            except lark.ParseError as ex:
                # Pull the column number out of the parser message, if any.
                if " col " in ex.args[0]:
                    column = ex.args[0][ex.args[0].find(" col ") + 5:].strip()
                    column = " at character {} of {}".format(
                        column, smarts_string)
                else:
                    column = ""

                malformed = ValidationError(
                    "Malformed SMARTS string{} on line {}".format(
                        column, entry.sourceline), ex, entry.sourceline)
                errors.append(malformed)
                continue

            # make sure referenced labels exist
            smarts_graph = SMARTSGraph(smarts_string,
                                       parser=self.smarts_parser,
                                       name=name,
                                       overrides=entry.attrib.get('overrides'))
            for atom_expr in nx.get_node_attributes(smarts_graph,
                                                    name='atom').values():
                labels = atom_expr.find_data('has_label')
                for label in labels:
                    # Drop the first character of the label token to get
                    # the referenced atom-type name.
                    atom_type = label.children[0][1:]
                    if atom_type not in self.atom_type_names:
                        undefined = ValidationError(
                            "Reference to undefined atomtype '{}' in SMARTS "
                            "string '{}' at line {}".format(
                                atom_type, entry.attrib['def'],
                                entry.sourceline), None, entry.sourceline)
                        errors.append(undefined)
        raise_collected(errors)
        if missing_smarts:
            if debug:
                warn(
                    "The following atom types do not have smarts definitions: {}".
                    format(', '.join(missing_smarts)), ValidationWarning)
            else:
                warn(
                    "There are {} atom types that are missing a smarts definition. "
                    "To view the missing atom types, re-run with debug=True when "
                    "applying the forcefield.".format(len(missing_smarts)),
                    ValidationWarning)
예제 #13
0
    def test_ring_count(self, smarts_parser):
        """Check the ``R<n>`` ring-count primitive on topology graphs.

        In ``fused.mol2`` the pattern ``[#6;R2]`` must match exactly atoms 3
        and 4 and ``[#6;R1]`` exactly the other eight atoms; in
        ``ring.mol2`` the pattern ``[#6;R1]`` must match exactly atoms 0
        through 5.
        """
        def blank_typemap(structure):
            # One empty entry per atom, keyed by the atom's ``idx``.
            return {
                atom.idx: {"whitelist": set(), "blacklist": set(), "atomtype": None}
                for atom in structure.atoms
            }

        def matched(graph, atom_typemap, smarts):
            rule = SMARTSGraph(
                name="test",
                parser=smarts_parser,
                smarts_string=smarts,
                typemap=atom_typemap,
            )
            return list(rule.find_matches(graph, atom_typemap))

        # Two rings
        fused = pmd.load_file(get_fn("fused.mol2"), structure=True)
        fused_graph = TopologyGraph.from_parmed(fused)
        fused_typemap = blank_typemap(fused)

        in_two_rings = matched(fused_graph, fused_typemap, "[#6;R2]")
        assert 3 in in_two_rings
        assert 4 in in_two_rings
        assert len(in_two_rings) == 2

        in_one_ring = matched(fused_graph, fused_typemap, "[#6;R1]")
        for atom_idx in (0, 1, 2, 5, 6, 7, 8, 9):
            assert atom_idx in in_one_ring
        assert len(in_one_ring) == 8

        # One ring
        ring = pmd.load_file(get_fn("ring.mol2"), structure=True)
        ring_typemap = blank_typemap(ring)
        ring_graph = TopologyGraph.from_parmed(ring)

        ring_atoms = matched(ring_graph, ring_typemap, "[#6;R1]")
        for atom_idx in range(6):
            assert atom_idx in ring_atoms
        assert len(ring_atoms) == 6
예제 #14
0
def read_search_mapping(search_mapping_filename, user_mapping_filename,
                        topology):
    """Match an ordered list of search strings against a topology.

    Parameters
    ----------
    search_mapping_filename : str
        Name of xml file containing ordered search parameters.  Currently
        not read: the search strings are hard-coded in the function body.
    user_mapping_filename : str
        Name of xml file containing molecules in the system.  Its
        ``molecule`` elements must carry a ``mol_str`` attribute.
    topology : mdTraj.Topology
        Topology object (to be expanded)

    Returns
    -------
    list of list
        One list of matches per search string, in search order.
    """
    # Imported locally so the function works regardless of which aliases
    # the surrounding module defines.
    import xml.etree.ElementTree as ET

    # TODO: read these from ``search_mapping_filename`` (one ``searchstr``
    # attribute per ``value`` element) instead of hard-coding them.
    searchlist = ["C", "CC", "CCC"]  # search values ordered by priority
    print("{0:s}: {1}".format("Search String", searchlist))

    # ``root`` was previously left undefined (its assignment had been
    # commented out), so the loop over molecules raised a NameError.
    root = ET.parse(user_mapping_filename).getroot()
    molecules = [
        molecule.attrib['mol_str']  # smarts string for molecule
        for molecule in root.findall('molecule')
    ]
    print("{0:s}: {1}".format("Molecules", molecules))

    parser = SMARTSParser()
    matches = []

    for searchstr in searchlist:
        print(searchstr)
        graph = SMARTSGraph(searchstr, parser=parser)
        matches.append(list(graph.find_matches(topology)))

    print(matches)

    return matches
예제 #15
0
    def test_fused_ring(self, smarts_parser):
        """Match the fused-ring SMARTS pattern against ``fused.mol2``.

        The structure is converted to a ``TopologyGraph`` and searched with
        an empty per-atom typemap; exactly two matches are expected, atom
        indices 3 and 4.
        """
        structure = pmd.load_file(get_fn("fused.mol2"), structure=True)
        topology_graph = TopologyGraph.from_parmed(structure)
        typemap = {}
        for atom in structure.atoms:
            typemap[atom.idx] = {
                "whitelist": set(),
                "blacklist": set(),
                "atomtype": None,
            }

        fused_pattern = "[#6]12[#6][#6][#6][#6][#6]1[#6][#6][#6][#6]2"
        rule = SMARTSGraph(
            name="test",
            parser=smarts_parser,
            smarts_string=fused_pattern,
            typemap=typemap,
        )

        found = list(rule.find_matches(topology_graph, typemap))
        for expected_idx in (3, 4):
            assert expected_idx in found
        assert len(found) == 2
예제 #16
0
def test_fused_ring():
    """Match the fused-ring SMARTS pattern directly against ``fused.mol2``.

    The loaded structure itself is searched, using an empty per-atom
    typemap; exactly two matches are expected, atom indices 3 and 4.
    """
    structure = pmd.load_file(get_fn('fused.mol2'), structure=True)
    typemap = {}
    for atom in structure.atoms:
        typemap[atom.idx] = {
            'whitelist': set(),
            'blacklist': set(),
            'atomtype': None
        }

    fused_pattern = '[#6]12[#6][#6][#6][#6][#6]1[#6][#6][#6][#6]2'
    rule = SMARTSGraph(name='test',
                       parser=PARSER,
                       smarts_string=fused_pattern,
                       typemap=typemap)

    found = list(rule.find_matches(structure, typemap))
    for expected_idx in (3, 4):
        assert expected_idx in found
    assert len(found) == 2
예제 #17
0
def test_ring_count():
    """Check the ``R<n>`` ring-count primitive on two test structures.

    Each case names a structure file and, per SMARTS pattern, the atom
    indices that must be matched together with the exact match count.
    """
    cases = (
        # Two rings
        ('fused.mol2', (
            ('[#6;R2]', (3, 4), 2),
            ('[#6;R1]', (0, 1, 2, 5, 6, 7, 8, 9), 8),
        )),
        # One ring
        ('ring.mol2', (
            ('[#6;R1]', range(6), 6),
        )),
    )
    for filename, expectations in cases:
        structure = pmd.load_file(get_fn(filename), structure=True)
        top, _ = generate_topology(structure)
        for smarts, expected_atoms, expected_count in expectations:
            rule = SMARTSGraph(name='test', parser=PARSER,
                               smarts_string=smarts)
            match_indices = list(rule.find_matches(top))
            for atom_idx in expected_atoms:
                assert atom_idx in match_indices
            assert len(match_indices) == expected_count
예제 #18
0
    def validate_smarts(self):
        """Validate the SMARTS definition of every entry in ``self.atom_types``.

        Entries without a ``def`` attribute are collected and reported in a
        single ``ValidationWarning``.  Definitions that fail to parse, and
        labels referencing atom types absent from ``self.atom_type_names``,
        are collected as ``ValidationError`` objects.

        Raises
        ------
        ValidationError
            If exactly one problem was found.
        MultipleValidationError
            If more than one problem was found.
        """
        undefined_names = []
        problems = []
        for entry in self.atom_types:
            smarts_string = entry.attrib.get('def')
            name = entry.attrib['name']
            if smarts_string is None:
                undefined_names.append(name)
                continue

            # The definition has to be parseable.
            try:
                self.smarts_parser.parse(smarts_string)
            except ParseError as ex:
                message = ex.args[0]
                column = ""
                if " col " in message:
                    # Everything after " col " is reported as the position.
                    position = message[message.find(" col ") + 5:].strip()
                    column = " at character {} of {}".format(
                        position, smarts_string)
                problems.append(
                    ValidationError(
                        "Malformed SMARTS string{} on line {}".format(
                            column, entry.sourceline), ex, entry.sourceline))
                continue

            # Every label in the definition has to name a known atom type.
            smarts_graph = SMARTSGraph(smarts_string,
                                       parser=self.smarts_parser,
                                       name=name,
                                       overrides=entry.attrib.get('overrides'))
            atom_exprs = nx.get_node_attributes(smarts_graph, 'atom').values()
            for atom_expr in atom_exprs:
                for label in atom_expr.select('has_label'):
                    atom_type = label.tail[0][1:]
                    if atom_type in self.atom_type_names:
                        continue
                    problems.append(
                        ValidationError(
                            "Reference to undefined atomtype '{}' in SMARTS string '{}' at line {}"
                            .format(atom_type, entry.attrib['def'],
                                    entry.sourceline), None, entry.sourceline))

        if len(problems) == 1:
            raise problems[0]
        if problems:
            raise MultipleValidationError(problems)
        if undefined_names:
            warn(
                "The following atom types do not have smarts definitions: {}".
                format(', '.join(undefined_names)), ValidationWarning)
예제 #19
0
def _load_rules(forcefield):
    """Build one ``SMARTSGraph`` per atom-type definition of a forcefield.

    Returns a dict mapping each rule name in
    ``forcefield.atomTypeDefinitions`` to its graph.  The graph's overrides
    are the entry of ``forcefield.atomTypeOverrides`` for that name, as a
    set, or an empty set when there is none.
    """
    rules = {}
    for rule_name, smarts in forcefield.atomTypeDefinitions.items():
        declared = forcefield.atomTypeOverrides.get(rule_name)
        overrides = set() if declared is None else set(declared)
        rules[rule_name] = SMARTSGraph(smarts_string=smarts,
                                       parser=forcefield.parser,
                                       name=rule_name,
                                       overrides=overrides)
    return rules
예제 #20
0
파일: validator.py 프로젝트: sallai/foyer
    def validate_smarts(self, ff_tree):
        """Validate the SMARTS definitions of a forcefield XML tree.

        Every ``/ForceField/AtomTypes/Type`` element is checked: its ``def``
        attribute must be parseable by ``self.smarts_parser`` and every
        label in it must name an existing atom type.  Types without a
        ``def`` attribute are reported together in one ``ValidationWarning``.

        Parameters
        ----------
        ff_tree : object supporting ``xpath``
            The parsed forcefield document.

        Raises
        ------
        ValidationError
            On the first malformed SMARTS string or reference to an
            undefined atom type.
        """
        results = ff_tree.xpath('/ForceField/AtomTypes/Type')
        atom_types = ff_tree.xpath('/ForceField/AtomTypes/Type/@name')

        missing_smarts = []
        for r in results:
            smarts_string = r.attrib.get('def')
            name = r.attrib['name']
            if smarts_string is None:
                missing_smarts.append(name)
                continue
            # make sure smarts string can be parsed
            try:
                self.smarts_parser.parse(smarts_string)
            except ParseError as ex:
                # Pull the column number out of the parser message, if any.
                if " col " in ex.args[0]:
                    column = ex.args[0][ex.args[0].find(" col ") + 5:].strip()
                    column = " at character {} of {}".format(
                        column, smarts_string)
                else:
                    column = ""

                raise ValidationError(
                    "Malformed SMARTS string{} on line {}".format(
                        column, r.sourceline), ex, r.sourceline)

            # make sure referenced labels exist
            smarts_graph = SMARTSGraph(smarts_string,
                                       parser=self.smarts_parser,
                                       name=name,
                                       overrides=r.attrib.get('overrides'))
            for atom_expr in nx.get_node_attributes(smarts_graph,
                                                    'atom').values():
                labels = atom_expr.select('has_label')
                for label in labels:
                    # Drop the first character of the label token to get
                    # the referenced atom-type name.
                    atom_type = label.tail[0][1:]
                    if atom_type not in atom_types:
                        raise ValidationError(
                            "Reference to undefined atomtype {} in SMARTS string"
                            " '{}' at line {}".format(atom_type,
                                                      r.attrib['def'],
                                                      r.sourceline), None,
                            r.sourceline)

        # Only warn when something is actually missing; previously the
        # warning was emitted with an empty list on every call.
        if missing_smarts:
            warn(
                "The following atom types do not have smarts definitions: {}".
                format(', '.join(missing_smarts)), ValidationWarning)
예제 #21
0
def _load_rules(forcefield, typemap):
    """Build one ``SMARTSGraph`` per non-empty atom-type definition.

    Returns a dict mapping rule names from
    ``forcefield.atomTypeDefinitions`` to graphs.  Empty definitions are
    skipped.  Each graph receives ``typemap`` and the overrides listed for
    its rule in ``forcefield.atomTypeOverrides`` (an empty set if none).
    """
    rules = {}
    for rule_name, smarts in forcefield.atomTypeDefinitions.items():
        if smarts:
            declared = forcefield.atomTypeOverrides.get(rule_name)
            overrides = set() if declared is None else set(declared)
            rules[rule_name] = SMARTSGraph(smarts_string=smarts,
                                           parser=forcefield.parser,
                                           name=rule_name,
                                           overrides=overrides,
                                           typemap=typemap)
    return rules
예제 #22
0
def test_lazy_cycle_finding():
    """Verify that ``cycles`` attributes are only set for ring patterns.

    Matching ``[C]`` against the ``ethane.mol2`` topology must leave every
    atom without a ``cycles`` attribute; matching ``[C;R1]`` or ``[C;r6]``
    must leave every atom with one.
    """
    def cycle_flags(topology):
        # One boolean per atom: does it carry a 'cycles' attribute?
        return [hasattr(atom, 'cycles') for atom in topology.atoms()]

    top, _ = generate_topology(
        pmd.load_file(get_fn('ethane.mol2'), structure=True))

    list(SMARTSGraph(smarts_string='[C]').find_matches(top))
    assert not any(cycle_flags(top))

    for ring_token in ['R1', 'r6']:
        ring_pattern = '[C;{}]'.format(ring_token)
        list(SMARTSGraph(smarts_string=ring_pattern).find_matches(top))
        assert all(cycle_flags(top))
예제 #23
0
def _load_rules(rules_provider, typemap):
    """Turn the SMARTS definitions of a rules provider into graphs.

    Returns a dict mapping each rule name in
    ``rules_provider.atomtype_definitions`` to a ``SMARTSGraph``.  Empty
    definitions are skipped.  Each graph receives ``typemap`` and the
    overrides listed for its rule in ``rules_provider.atomtype_overrides``
    (an empty set if none).
    """
    rules = {}
    for rule_name, smarts in rules_provider.atomtype_definitions.items():
        if smarts:
            declared = rules_provider.atomtype_overrides.get(rule_name)
            graph_options = dict(
                smarts_string=smarts,
                parser=rules_provider.parser,
                name=rule_name,
                overrides=set() if declared is None else set(declared),
                typemap=typemap,
            )
            rules[rule_name] = SMARTSGraph(**graph_options)
    return rules
예제 #24
0
def test_lazy_cycle_finding():
    """Verify that ``'cycles'`` typemap entries only appear for ring patterns.

    Matching ``[C]`` against ``ethane.mol2`` must leave every atom's typemap
    entry without a ``'cycles'`` key; matching ``[C;R1]`` or ``[C;r6]``
    must leave every entry with one.
    """
    structure = pmd.load_file(get_fn('ethane.mol2'), structure=True)
    typemap = {}
    for atom in structure.atoms:
        typemap[atom.idx] = {
            'whitelist': set(),
            'blacklist': set(),
            'atomtype': None
        }

    plain_rule = SMARTSGraph(smarts_string='[C]', typemap=typemap)
    list(plain_rule.find_matches(structure, typemap))
    cycle_flags = ['cycles' in typemap[atom.idx] for atom in structure.atoms]
    assert not any(cycle_flags)

    for ring_token in ('R1', 'r6'):
        ring_rule = SMARTSGraph(smarts_string='[C;{}]'.format(ring_token),
                                typemap=typemap)
        list(ring_rule.find_matches(structure, typemap))
        cycle_flags = [
            'cycles' in typemap[atom.idx] for atom in structure.atoms
        ]
        assert all(cycle_flags)
예제 #25
0
def test_lazy_cycle_finding():
    """Verify that ``'cycles'`` typemap entries only appear for ring patterns.

    The ``ethane.mol2`` topology is searched with an initially empty
    per-atom typemap.  After matching ``[C]`` no entry may contain a
    ``'cycles'`` key; after matching ``[C;R1]`` or ``[C;r6]`` every entry
    must contain one.
    """
    def cycle_flags(topology, atom_typemap):
        # One boolean per atom: does its typemap entry have a 'cycles' key?
        return ['cycles' in atom_typemap[a.index] for a in topology.atoms()]

    structure = pmd.load_file(get_fn('ethane.mol2'), structure=True)
    top, _ = generate_topology(structure)
    typemap = {}
    for atom in top.atoms():
        typemap[atom.index] = {
            'whitelist': set(),
            'blacklist': set(),
            'atomtype': None
        }

    plain_rule = SMARTSGraph(smarts_string='[C]', typemap=typemap)
    list(plain_rule.find_matches(top, typemap))
    assert not any(cycle_flags(top, typemap))

    for ring_token in ('R1', 'r6'):
        ring_rule = SMARTSGraph(smarts_string='[C;{}]'.format(ring_token),
                                typemap=typemap)
        list(ring_rule.find_matches(top, typemap))
        assert all(cycle_flags(top, typemap))
예제 #26
0
def _rule_match(top, smart, result):
    """Assert whether the SMARTS pattern ``smart`` matches anything in ``top``.

    ``result`` must be the boolean ``True`` when at least one match is
    expected and ``False`` when none is.
    """
    graph = SMARTSGraph(name='test', parser=PARSER, smarts_string=smart)
    found = list(graph.find_matches(top))
    assert bool(found) is result
예제 #27
0
def _rule_match_count(top, smart, count):
    """Assert that the SMARTS pattern ``smart`` matches ``count`` times.

    Parameters
    ----------
    top : object accepted by ``SMARTSGraph.find_matches``
        The structure to search.
    smart : str
        SMARTS string used to build the rule.
    count : int
        Expected number of matches.

    Raises
    ------
    AssertionError
        If the number of matches differs from ``count``.
    """
    rule = SMARTSGraph(name='test', parser=PARSER, smarts_string=smart)
    # Compare by value: ``is`` tests object identity, which only happens to
    # hold for the small integers CPython caches.
    assert len(list(rule.find_matches(top))) == count
예제 #28
0
def _rule_match(top, smart, result):
    """Check that matching ``smart`` against ``top`` succeeds iff ``result``.

    A ``SMARTSGraph`` is built from the pattern and the presence of at
    least one match is compared, by identity, with the boolean ``result``.
    """
    matches = list(
        SMARTSGraph(name='test', parser=PARSER,
                    smarts_string=smart).find_matches(top))
    has_match = len(matches) > 0
    assert has_match is result
예제 #29
0
def _rule_match_count(top, smart, count):
    """Assert that matching ``smart`` against ``top`` yields ``count`` hits.

    Parameters
    ----------
    top : object accepted by ``SMARTSGraph.find_matches``
        The structure to search.
    smart : str
        SMARTS string used to build the rule.
    count : int
        Expected number of matches.

    Raises
    ------
    AssertionError
        If the number of matches differs from ``count``.
    """
    rule = SMARTSGraph(name='test', parser=PARSER, smarts_string=smart)
    matches = list(rule.find_matches(top))
    # Integers must be compared with ``==``; identity (``is``) is only
    # reliable for the small values that CPython happens to cache.
    assert len(matches) == count