def test_coverage_filter_allowed():
    """
    Check that a coverage filter restricted to the ``b83`` parameter only keeps
    molecules which are actually assigned that parameter, and that the kept
    molecules carry no ``dihedrals`` property.
    """

    coverage_filter = workflow_components.CoverageFilter(allowed_ids={"b83"})

    molecule_container = get_container(get_stereoisomers())
    result = coverage_filter.apply(
        molecule_container.molecules,
        processors=1,
        toolkit_registry=GLOBAL_TOOLKIT_REGISTRY,
    )

    forcefield = ForceField("openff_unconstrained-1.0.0.offxml")

    # Map every parameter id to the SMILES of the retained molecules using it.
    smiles_by_parameter = {}
    for molecule in result.molecules:
        labels = forcefield.label_molecules(molecule.to_topology())[0]
        parameter_ids = {
            label.id
            for handler_labels in labels.values()
            for label in handler_labels.values()
        }
        for parameter_id in parameter_ids:
            smiles_by_parameter.setdefault(parameter_id, []).append(
                molecule.to_smiles())

    # Every retained molecule must be among those assigned b83.
    b83_smiles = smiles_by_parameter["b83"]
    for molecule in result.molecules:
        assert molecule.to_smiles() in b83_smiles
        assert "dihedrals" not in molecule.properties
def checkTorsion(molList, ff_name):
    """
    Tag the molecules whose driven torsion is assigned a proper-torsion
    parameter by the given force field.

    Parameters
    ----------
    molList : list of objects
        List of oemols with datatags generated in genData function. Each is
        read for a "cmiles" datatag (mapped SMILES) and a "TDindices" datatag
        (the atom indices of the driven torsion).
    ff_name : str
        Name/path of the force field passed to ``ForceField``.

    Returns
    -------
    mols : list of objects
        List of oemol objects that have a datatag "IDMatch" that contain the
        torsion id involved in the QCA torsion drive. A molecule is appended
        once per matching torsion entry.
    """

    # The force field does not depend on the molecule, so parse it once rather
    # than once per molecule.
    forcefield = ForceField(ff_name, allow_cosmetic_attributes=True)

    count = 0
    mols = []
    for mol in molList:
        molecule = Molecule.from_mapped_smiles(mol.GetData("cmiles"))
        topology = Topology.from_molecules(molecule)
        td_indices = mol.GetData("TDindices")
        # Label the molecule and look only at the proper torsions applied.
        for mol_forces in forcefield.label_molecules(topology):
            for force_tag, force_dict in mol_forces.items():
                if force_tag != "ProperTorsions":
                    continue
                for atom_indices, parameter in force_dict.items():
                    # A torsion may be stored in either direction.
                    if (atom_indices == td_indices
                            or tuple(reversed(atom_indices)) == td_indices):
                        count += 1
                        mol.SetData("IDMatch", parameter.id)
                        mols.append(mol)

    print(f"Out of {len(molList)} molecules, {count} were processed with checkTorsion()")

    return mols
def smirnoff_coverage(molecules: Iterable[Molecule],
                      force_field: ForceField,
                      verbose: bool = False) -> Dict[str, Dict[str, int]]:
    """Returns a summary of how many of the specified molecules would be assigned each
    of the parameters in a force field.

    Notes:
        * Parameters which would not be assigned to any molecules of the specified
          molecules will not be included in the returned summary.
        * Molecules are counted by their unmapped, non-isomeric SMILES, so duplicate
          molecules (and stereoisomers of the same molecule) are only counted once.

    Args:
        molecules: The molecules to generate a coverage report for.
        force_field: The force field containing the parameters to summarize.
        verbose: If true a progress bar will be shown on screen.

    Returns:
        A dictionary of the form ``coverage[handler_name][parameter_smirks] = count``
        which stores the number of unique molecules that would be assigned to each
        parameter.
    """

    molecules = [*molecules]

    # coverage[handler_name][smirks] -> set of unique SMILES assigned that parameter.
    coverage = defaultdict(lambda: defaultdict(set))

    for molecule in tqdm(
            molecules,
            total=len(molecules),
            ncols=80,
            disable=not verbose,
    ):

        full_labels = force_field.label_molecules(molecule.to_topology())[0]

        # The SMILES key is the same for every parameter of this molecule, so
        # compute it once instead of inside the inner loop.
        smiles = molecule.to_smiles(mapped=False, isomeric=False)

        for handler_name, parameter_labels in full_labels.items():
            for parameter in parameter_labels.values():
                coverage[handler_name][parameter.smirks].add(smiles)

    # Convert the defaultdict objects back into ordinary dicts so that users get
    # KeyError exceptions when trying to access missing handlers / smirks.
    return {
        handler_name: {
            smirks: len(unique_smiles)
            for smirks, unique_smiles in smiles_by_smirks.items()
        }
        for handler_name, smiles_by_smirks in coverage.items()
    }
def single_molecule_coverage(molecule: Molecule, forcefield: ForceField):
    """
    For a single molecule generate a coverage report and try to build an OpenMM
    system. This also highlights any missing parameters and difficulties with
    charging the molecule.

    Parameters
    ----------
    molecule: Molecule
        The openff-toolkit molecule object for which the report should be generated.
    forcefield: ForceField
        The openff-toolkit typing engine that should be used to check coverage and
        build an OpenMM system.

    Returns
    -------
    report: dict
        The coverage report. It maps each handler name to a dict of
        ``parameter id -> number of times that parameter was assigned``.
        It always contains the "Angles", "Bonds", "ProperTorsions",
        "ImproperTorsions" and "vdW" keys, plus any other handler that assigned
        parameters. The input molecule is stored under the "molecule" key.
    e: Exception or None
        The exception raised in this step, if any.
        If not None, it should be assumed that coverage is invalid.
    """

    coverage = {
        "Angles": {},
        "Bonds": {},
        "ProperTorsions": {},
        "ImproperTorsions": {},
        "vdW": {}
    }
    coverage["molecule"] = molecule
    try:
        labels = forcefield.label_molecules(molecule.to_topology())[0]
        for param_type, params in labels.items():
            for assigned in params.values():
                # Handlers beyond the five pre-seeded ones may also assign
                # parameters; record them rather than failing with a KeyError.
                type_counts = coverage.setdefault(param_type, {})
                # Some handlers assign a list of parameters per match rather
                # than a single parameter object.
                assigned_list = assigned if isinstance(assigned, list) else [assigned]
                for param in assigned_list:
                    type_counts[param.id] = type_counts.get(param.id, 0) + 1
        # now generate a system this will catch any missing parameters
        # and molecules that can not be charged
        _ = forcefield.create_openmm_system(molecule.to_topology())
        return coverage, None
    except Exception as e:
        # Deliberately broad: any failure is reported to the caller, not raised.
        return coverage, e
Example #5
0
def test_coverage_filter():
    """
    Make sure the coverage filter removes the correct molecules.

    Retained molecules must be assigned the allowed parameter ``b83``.
    Re-applying the filter to its own output must remove nothing and keep
    every molecule.
    """
    from openff.toolkit.typing.engines.smirnoff import ForceField

    coverage_filter = workflow_components.CoverageFilter()
    coverage_filter.allowed_ids = ["b83"]
    coverage_filter.filtered_ids = ["b87"]

    mols = get_stereoisomers()

    molecule_container = get_container(mols)
    result = coverage_filter.apply(molecule_container.molecules, processors=1)

    forcefield = ForceField("openff_unconstrained-1.0.0.offxml")
    # Map each parameter id to the SMILES of the retained molecules that use it.
    parameters_by_id = {}
    for molecule in result.molecules:
        labels = forcefield.label_molecules(molecule.to_topology())[0]
        covered_types = {
            label.id for types in labels.values() for label in types.values()
        }
        for parameter in covered_types:
            parameters_by_id.setdefault(parameter, []).append(molecule.to_smiles())

    expected = parameters_by_id["b83"]
    for molecule in result.molecules:
        assert molecule.to_smiles() in expected
        assert "dihedrals" not in molecule.properties

    # we now need to check that the molecules passed contain only the allowed atoms
    # do this by running the component again: nothing should be filtered and the
    # molecule count must be unchanged.
    result2 = coverage_filter.apply(result.molecules, processors=1)
    assert result2.n_filtered == 0
    assert result2.n_molecules == result.n_molecules
def test_coverage_filter_remove():
    """
    Check that a coverage filter configured with ``filtered_ids={"b87"}`` drops
    every molecule that would be assigned the ``b87`` parameter.
    """

    coverage_filter = workflow_components.CoverageFilter(filtered_ids={"b87"})
    molecule_container = get_container(get_stereoisomers())
    result = coverage_filter.apply(
        molecule_container.molecules,
        processors=1,
        toolkit_registry=GLOBAL_TOOLKIT_REGISTRY,
    )

    forcefield = ForceField("openff_unconstrained-1.0.0.offxml")
    # No retained molecule may be assigned the filtered parameter b87.
    for molecule in result.molecules:
        labels = forcefield.label_molecules(molecule.to_topology())[0]
        assigned_ids = {
            label.id
            for handler_labels in labels.values()
            for label in handler_labels.values()
        }
        assert "b87" not in assigned_ids
Example #7
0
    def get_parameters_from_forcefield(self, forcefield, molecule):
        """
        It returns the parameters that are obtained with the supplied
        forcefield for a certain peleffy's molecule.

        Parameters
        ----------
        forcefield : str or an openforcefield.typing.engines.smirnoff.ForceField
                     object
            The forcefield from which the parameters will be obtained. A str
            is used to construct a new ForceField.
        molecule : an peleffy.topology.Molecule
            The peleffy's Molecule object

        Returns
        -------
        openforcefield_parameters : dict
            The OpenFF parameters stored in a dict keyed by parameter type

        Raises
        ------
        TypeError
            If ``forcefield`` is neither a str nor a ForceField.
        AssertionError
            If labelling does not return the parameters of exactly one molecule.
        """
        from openff.toolkit.typing.engines.smirnoff import ForceField
        from openff.toolkit.topology import Topology

        off_molecule = molecule.off_molecule
        topology = Topology.from_molecules([off_molecule])

        if isinstance(forcefield, str):
            forcefield = ForceField(forcefield)
        elif not isinstance(forcefield, ForceField):
            # TypeError subclasses Exception, so callers catching the previous
            # generic Exception keep working.
            raise TypeError(
                'Invalid forcefield type: expected a str or a ForceField, '
                'got {}'.format(type(forcefield).__name__))

        molecule_parameters_list = forcefield.label_molecules(topology)

        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if len(molecule_parameters_list) != 1:
            raise AssertionError('A single molecule is expected')

        return molecule_parameters_list[0]
def smirnoff_torsion_coverage(
    molecules: Iterable[Tuple[Molecule, Tuple[int, int, int, int]]],
    force_field: ForceField,
    verbose: bool = False,
):
    """Returns a summary of how many unique molecules within this collection
    would be assigned each of the parameters in a force field.

    Only valence parameters (bonds, angles, proper and improper torsions) which
    involve the central bond of each torsion are counted. Proper torsions count
    only when their own central bond is the torsion's central bond. The other
    terms count when any consecutive pair of their atoms is that bond.

    Notes:
        * Parameters which would not be assigned to any molecules in the collection
          will not be included in the returned summary.
        * Each torsion is identified by a mapped, non-isomeric SMILES in which the
          four torsion atoms are tagged 1-4, so repeated entries count once.

    Args:
        molecules: The molecules and associated torsion (as defined by a quartet of
            atom indices) to generate a coverage report for.
        force_field: The force field containing the parameters to summarize.
        verbose: If true a progress bar will be shown on screen.

    Returns:
        A dictionary of the form ``coverage[handler_name][parameter_smirks] = count``
        which stores the number of unique torsions within this collection that
        would be assigned to each parameter.
    """

    molecules = [*molecules]

    valence_handlers = {
        "Bonds",
        "Angles",
        "ProperTorsions",
        "ImproperTorsions",
    }

    labelled_molecules = {}

    # Only label each unique molecule once as this is a pretty slow operation.
    for molecule, _ in tqdm(
            molecules,
            total=len(molecules),
            ncols=80,
            desc="Assigning Parameters",
            disable=not verbose,
    ):

        smiles = molecule.to_smiles(isomeric=False, mapped=False)

        if smiles in labelled_molecules:
            continue

        labelled_molecules[smiles] = force_field.label_molecules(
            molecule.to_topology())[0]

    # coverage[handler_name][smirks] -> set of tagged torsion SMILES.
    coverage = defaultdict(lambda: defaultdict(set))

    for molecule, dihedral in tqdm(
            molecules,
            total=len(molecules),
            ncols=80,
            desc="Summarising",
            disable=not verbose,
    ):

        smiles = molecule.to_smiles(isomeric=False, mapped=False)
        full_labels = labelled_molecules[smiles]

        # Tag the four torsion atoms with map indices 1-4 to identify the torsion.
        tagged_molecule = copy.deepcopy(molecule)
        tagged_molecule.properties["atom_map"] = {
            j: i + 1
            for i, j in enumerate(dihedral)
        }
        tagged_smiles = tagged_molecule.to_smiles(isomeric=False, mapped=True)

        dihedral_indices = {*dihedral[1:3]}

        for handler_name, parameter_labels in full_labels.items():

            # The handler name does not change per parameter, so filter here.
            if handler_name not in valence_handlers:
                continue

            for indices, parameter in parameter_labels.items():

                consecutive_pairs = [{*pair}
                                     for pair in zip(indices, indices[1:])]

                if handler_name == "ProperTorsions":
                    # Only torsions sharing the central dihedral bond count.
                    involves_central_bond = consecutive_pairs[1] == dihedral_indices
                else:
                    # Bonds, angles and impropers count if any consecutive pair
                    # of their atoms is the central dihedral bond.
                    involves_central_bond = any(
                        pair == dihedral_indices for pair in consecutive_pairs)

                if not involves_central_bond:
                    continue

                coverage[handler_name][parameter.smirks].add(tagged_smiles)

    return {
        handler_name: {
            smirks: len(tagged_torsions)
            for smirks, tagged_torsions in torsions_by_smirks.items()
        }
        for handler_name, torsions_by_smirks in coverage.items()
    }
Example #9
0
class SMIRNOFF(OpenMM):
    """ Derived from Engine object for carrying out OpenMM calculations that use the SMIRNOFF force field. """
    def __init__(self, name="openmm", **kwargs):
        # Keyword options understood by this engine; set before the base-class
        # initialiser runs (NOTE(review): presumably used by the base Engine to
        # validate kwargs - confirm).
        self.valkwd = [
            'ffxml', 'pdb', 'mol2', 'platname', 'precision', 'mmopts',
            'vsite_bonds', 'implicit_solvent', 'restrain_k', 'freeze_atoms'
        ]
        if not toolkit_import_success:
            warn_once(
                "Note: Failed to import the OpenFF Toolkit - SMIRNOFF Engine will not work. "
            )
        super(SMIRNOFF, self).__init__(name=name, **kwargs)

    def readsrc(self, **kwargs):
        """
        Read the topology, coordinates and molecule templates for a SMIRNOFF
        simulation from ``kwargs``.

        Parameters
        ----------
        pdb : string, optional
            Name of a .pdb file containing the topology of the system. If given,
            the file must exist; its chain/atomname/resid/resname/elem data are
            copied onto ``self.mol``.
        mol2 : list
            A list of .mol2 file names containing the molecule/residue templates of the system

        Also provide 1 of the following, containing the coordinates to be used:
        mol : Molecule
            forcebalance.Molecule object
        coords : string
            Name of a file (readable by forcebalance.Molecule)
            This could be the same as the pdb argument from above.

        Raises
        ------
        RuntimeError
            If a given pdb, coords or mol2 file does not exist (mol2 files that
            belong to the force field are exempt), or if neither ``mol`` nor
            ``coords`` is provided.
        """

        # Determine the PDB file name. It is optional, but when one is given
        # explicitly it must exist. The existence check is kept separate from
        # the presence check so that a bad path is reported, not silently ignored.
        pdbfnm = kwargs.get('pdb') or None
        if pdbfnm is not None and not os.path.exists(pdbfnm):
            logger.error("%s specified but doesn't exist\n" % pdbfnm)
            raise RuntimeError

        if 'mol' in kwargs:
            self.mol = kwargs['mol']
        elif 'coords' in kwargs:
            if not os.path.exists(kwargs['coords']):
                logger.error("%s specified but doesn't exist\n" %
                             kwargs['coords'])
                raise RuntimeError
            self.mol = Molecule(kwargs['coords'])
        else:
            logger.error(
                'Must provide either a molecule object or coordinate file.\n')
            raise RuntimeError

        # Here we cannot distinguish the .mol2 files linked by the target
        # vs. the .mol2 files to be provided by the force field.
        # But we can assume that these files should exist when this function is called.

        self.mol2 = kwargs.get('mol2')
        if self.mol2:
            for fnm in self.mol2:
                if not os.path.exists(fnm):
                    # Force-field files may legitimately not be written yet.
                    if hasattr(self, 'FF') and fnm in self.FF.fnms: continue
                    logger.error("%s doesn't exist" % fnm)
                    raise RuntimeError
        else:
            # NOTE(review): only logged, not raised; prepare() later iterates
            # self.mol2 and will fail if it is None.
            logger.error("Must provide a list of .mol2 files.\n")

        if pdbfnm is not None:
            self.abspdb = os.path.abspath(pdbfnm)
            mpdb = Molecule(pdbfnm)
            for i in ["chain", "atomname", "resid", "resname", "elem"]:
                self.mol.Data[i] = mpdb.Data[i]

        # Store a separate copy of the molecule for reference restraint positions.
        self.ref_mol = deepcopy(self.mol)

    @staticmethod
    def _openff_to_openmm_topology(openff_topology):
        """Convert an OpenFF topology to an OpenMM topology. Currently this requires
        manually adding the v-sites as OpenFF currently does not."""

        from openff.toolkit.topology import TopologyAtom

        openmm_topology = openff_topology.to_openmm()

        # Return the topology if the number of OpenMM particles matches the number
        # expected by the OpenFF topology. This may happen if there are no virtual sites
        # in the system OR if a new version of the the OpenFF toolkit includes virtual
        # sites in the OpenMM topology it returns.
        if openmm_topology.getNumAtoms(
        ) == openff_topology.n_topology_particles:
            return openmm_topology

        # Put every virtual site into a single new chain/residue as a massless atom.
        openmm_chain = openmm_topology.addChain()
        openmm_residue = openmm_topology.addResidue("", chain=openmm_chain)

        for particle in openff_topology.topology_particles:

            if isinstance(particle, TopologyAtom):
                continue

            openmm_topology.addAtom(particle.virtual_site.name,
                                    app.Element.getByMass(0), openmm_residue)

        return openmm_topology

    def prepare(self, pbc=False, mmopts={}, **kwargs):
        """
        Prepare the calculation.  Note that we don't create the
        Simulation object yet, because that may depend on MD
        integrator parameters, thermostat, barostat etc.

        This is mostly copied and modified from openmmio.py's OpenMM.prepare(),
        but we are calling ForceField() from the OpenFF toolkit and ignoring
        AMOEBA stuff.

        ``mmopts`` is copied before use, so the shared default is never mutated.
        """

        if hasattr(self, 'abspdb'):
            self.pdb = PDBFile(self.abspdb)
        else:
            # Write the first frame to a temporary PDB just to build the topology.
            pdb1 = "%s-1.pdb" % os.path.splitext(os.path.basename(
                self.mol.fnm))[0]
            self.mol[0].write(pdb1)
            try:
                self.pdb = PDBFile(pdb1)
            finally:
                # Remove the temporary file even if parsing it fails.
                os.unlink(pdb1)

        # Create the OpenFF ForceField object.
        if hasattr(self, 'FF'):
            self.offxml = [self.FF.offxml]
            self.forcefield = self.FF.openff_forcefield
        else:
            self.offxml = listfiles(kwargs.get('offxml'), 'offxml', err=True)
            self.forcefield = OpenFF_ForceField(*self.offxml,
                                                load_plugins=True)

        ## Load mol2 files for smirnoff topology
        openff_mols = []
        for fnm in self.mol2:
            try:
                mol = OffMolecule.from_file(fnm)
            except Exception as e:
                logger.error("Error when loading %s" % fnm)
                raise e
            openff_mols.append(mol)
        self.off_topology = OffTopology.from_openmm(
            self.pdb.topology, unique_molecules=openff_mols)

        ## OpenMM options for setting up the System.
        self.mmopts = dict(mmopts)

        ## Specify frozen atoms and restraint force constant
        if 'restrain_k' in kwargs:
            self.restrain_k = kwargs['restrain_k']
        if 'freeze_atoms' in kwargs:
            self.freeze_atoms = kwargs['freeze_atoms'][:]

        ## Set system options from ForceBalance force field options.
        fftmp = False
        if hasattr(self, 'FF'):
            self.mmopts['rigidWater'] = self.FF.rigid_water
            if not all([os.path.exists(f) for f in self.FF.fnms]):
                # If the parameter files don't already exist, create them for the purpose of
                # preparing the engine, but then delete them afterward.
                fftmp = True
                self.FF.make(np.zeros(self.FF.np))

        ## Set system options from periodic boundary conditions.
        self.pbc = pbc
        ## print warning for 'nonbonded_cutoff' keywords
        if 'nonbonded_cutoff' in kwargs:
            logger.warning(
                "nonbonded_cutoff keyword ignored because it's set in the offxml file\n"
            )

        # Apply the FF parameters to the system. Currently this is the only way to
        # determine if the FF will apply virtual sites to the system.
        _, openff_topology = self.forcefield.create_openmm_system(
            self.off_topology, return_topology=True)

        ## Generate OpenMM-compatible positions
        self.xyz_omms = []

        for I in range(len(self.mol)):
            xyz = self.mol.xyzs[I]
            xyz_omm = ([Vec3(i[0], i[1], i[2]) for i in xyz]
                       # Add placeholder positions for an v-sites.
                       + [Vec3(0.0, 0.0, 0.0)] *
                       openff_topology.n_topology_virtual_sites) * angstrom

            if self.pbc:
                # Obtain the periodic box
                if self.mol.boxes[I].alpha != 90.0 or self.mol.boxes[
                        I].beta != 90.0 or self.mol.boxes[I].gamma != 90.0:
                    logger.error('OpenMM cannot handle nonorthogonal boxes.\n')
                    raise RuntimeError
                box_omm = np.diag([
                    self.mol.boxes[I].a, self.mol.boxes[I].b,
                    self.mol.boxes[I].c
                ]) * angstrom
            else:
                box_omm = None
            # Finally append it to list.
            self.xyz_omms.append((xyz_omm, box_omm))

        # used in create_simulation()
        openmm_topology = SMIRNOFF._openff_to_openmm_topology(openff_topology)
        openmm_positions = (
            self.pdb.positions.value_in_unit(angstrom) +
            # Add placeholder positions for an v-sites.
            [Vec3(0.0, 0.0, 0.0)] *
            openff_topology.n_topology_virtual_sites) * angstrom

        self.mod = Modeller(openmm_topology, openmm_positions)

        ## Build a topology and atom lists.
        Top = self.mod.getTopology()
        Atoms = list(Top.atoms())

        # vss = [(i, [system.getVirtualSite(i).getParticle(j) for j in range(system.getVirtualSite(i).getNumParticles())]) \
        #            for i in range(system.getNumParticles()) if system.isVirtualSite(i)]
        self.AtomLists = defaultdict(list)
        # Element-less particles (the added v-sites) get zero mass.
        self.AtomLists['Mass'] = [
            a.element.mass.value_in_unit(dalton)
            if a.element is not None else 0 for a in Atoms
        ]
        # Particles lighter than 1 dalton are classed as 'D', the rest as atoms 'A'.
        self.AtomLists['ParticleType'] = [
            'A' if m >= 1.0 else 'D' for m in self.AtomLists['Mass']
        ]
        self.AtomLists['ResidueNumber'] = [a.residue.index for a in Atoms]
        self.AtomMask = [a == 'A' for a in self.AtomLists['ParticleType']]
        self.realAtomIdxs = [i for i, a in enumerate(self.AtomMask) if a]
        if hasattr(self, 'FF') and fftmp:
            for f in self.FF.fnms:
                os.unlink(f)

    def update_simulation(self, **kwargs):
        """
        Create the simulation object, or update the force field
        parameters in the existing simulation object.  This should be
        run when we write a new force field XML file.
        """
        if len(kwargs) > 0:
            self.simkwargs = kwargs

        # Because self.forcefield is being updated in forcebalance.forcefield.FF.make()
        # there is no longer a need to create a new force field object here.
        try:
            self.system, openff_topology = self.forcefield.create_openmm_system(
                self.off_topology, return_topology=True)
        except Exception as error:
            logger.error("Error when creating system for %s" % self.mol2)
            raise error
        # Commenting out all virtual site stuff for now.
        # self.vsinfo = PrepareVirtualSites(self.system)
        self.nbcharges = np.zeros(self.system.getNumParticles())

        #----
        # If the virtual site parameters have changed,
        # the simulation object must be remade.
        #----
        # vsprm = GetVirtualSiteParameters(self.system)
        # if hasattr(self,'vsprm') and len(self.vsprm) > 0 and np.max(np.abs(vsprm - self.vsprm)) != 0.0:
        #     if hasattr(self, 'simulation'):
        #         delattr(self, 'simulation')
        # self.vsprm = vsprm.copy()

        if openff_topology.n_topology_virtual_sites > 0:
            # For now always assume that the v-sites have changed. This is currently
            # needed as the FB checks don't support the ``LocalCoordinatesSite`` based
            # virtual sites that OpenFF uses.
            if hasattr(self, 'simulation'):
                delattr(self, 'simulation')

        if hasattr(self, 'simulation'):
            UpdateSimulationParameters(self.system, self.simulation)
        else:
            self.create_simulation(**self.simkwargs)

    def _update_positions(self, X1, disable_vsite):
        """Set context positions from X1, padding with zeros for v-sites.

        X1 may be a numpy ndarray or a list of Vec3; in both branches below it
        is multiplied by ``angstrom`` (NOTE(review): so presumably given in
        angstrom - confirm).
        """

        if disable_vsite:
            super(SMIRNOFF, self)._update_positions(X1, disable_vsite)
            return

        # Number of particles the Modeller topology has beyond the PDB atoms.
        n_v_sites = (self.mod.getTopology().getNumAtoms() -
                     self.pdb.topology.getNumAtoms())

        # Add placeholder positions for an v-sites.
        if isinstance(X1, np.ndarray):
            X1 = np.vstack([X1, np.zeros((n_v_sites, 3))]) * angstrom
        else:
            X1 = (X1 + [Vec3(0.0, 0.0, 0.0)] * n_v_sites) * angstrom

        self.simulation.context.setPositions(X1)
        # Let OpenMM place the v-sites from the real atom positions.
        self.simulation.context.computeVirtualSites()

    def interaction_energy(self, fraga, fragb):
        """
        Calculate the interaction energy for two fragments.
        Because this creates two new objects and requires passing in the mol2 argument,
        the codes are copied and modified from the OpenMM class.
        """

        self.update_simulation()

        if self.name == 'A' or self.name == 'B':
            logger.error("Don't name the engine A or B!\n")
            raise RuntimeError

        # Create two subengines.
        if hasattr(self, 'target'):
            if not hasattr(self, 'A'):
                self.A = SMIRNOFF(name="A",
                                  mol=self.mol.atom_select(fraga),
                                  mol2=self.mol2,
                                  target=self.target)
            if not hasattr(self, 'B'):
                self.B = SMIRNOFF(name="B",
                                  mol=self.mol.atom_select(fragb),
                                  mol2=self.mol2,
                                  target=self.target)
        else:
            if not hasattr(self, 'A'):
                self.A = SMIRNOFF(name="A", mol=self.mol.atom_select(fraga), mol2=self.mol2, platname=self.platname, \
                                  precision=self.precision, offxml=self.offxml, mmopts=self.mmopts)
            if not hasattr(self, 'B'):
                self.B = SMIRNOFF(name="B", mol=self.mol.atom_select(fragb), mol2=self.mol2, platname=self.platname, \
                                  precision=self.precision, offxml=self.offxml, mmopts=self.mmopts)

        # Interaction energy needs to be in kcal/mol.
        D = self.energy()
        A = self.A.energy()
        B = self.B.energy()

        return (D - A - B) / 4.184

    def get_smirks_counter(self):
        """Count how many times each SMIRKS pattern is assigned, summed over
        the per-molecule labels returned for ``self.off_topology``."""
        smirks_counter = Counter()
        molecule_force_list = self.forcefield.label_molecules(
            self.off_topology)
        for mol_forces in molecule_force_list:
            # Keys are handler tags, e.g. 'Bonds'.
            for force_dict in mol_forces.values():
                for parameters in force_dict.values():

                    # Some handlers assign a list of parameters per match.
                    if not isinstance(parameters, list):
                        parameters = [parameters]

                    for parameter in parameters:
                        smirks_counter[parameter.smirks] += 1

        return smirks_counter
Example #10
0
class SMIRNOFF(OpenMM):
    """ Derived from Engine object for carrying out OpenMM calculations that use the SMIRNOFF force field. """
    def __init__(self, name="openmm", **kwargs):
        # Keyword names specific to this engine (presumably the list of valid
        # keywords consumed by the base Engine class, which is not shown here).
        self.valkwd = [
            'ffxml', 'pdb', 'mol2', 'platname', 'precision', 'mmopts',
            'vsite_bonds', 'implicit_solvent', 'restrain_k', 'freeze_atoms'
        ]
        if not toolkit_import_success:
            warn_once(
                "Note: Failed to import the OpenFF Toolkit - SMIRNOFF Engine will not work. "
            )
        super(SMIRNOFF, self).__init__(name=name, **kwargs)

    def readsrc(self, **kwargs):
        """
        Read the topology, coordinate and .mol2 sources for the engine.

        SMIRNOFF simulations always require the following passed in via kwargs:

        Parameters
        ----------
        pdb : string
            Name of a .pdb file containing the topology of the system
        mol2 : list
            A list of .mol2 file names containing the molecule/residue templates of the system

        Also provide 1 of the following, containing the coordinates to be used:
        mol : Molecule
            forcebalance.Molecule object
        coords : string
            Name of a file (readable by forcebalance.Molecule)
            This could be the same as the pdb argument from above.

        Raises
        ------
        RuntimeError
            If the PDB file, the coordinates or the .mol2 list is missing, or
            a listed .mol2 file does not exist (and is not one of the force
            field's own files).
        """
        pdbfnm = kwargs.get('pdb')
        # Determine the PDB file name.
        if not pdbfnm:
            raise RuntimeError('Name of PDB file not provided.')
        elif not os.path.exists(pdbfnm):
            logger.error("%s specified but doesn't exist\n" % pdbfnm)
            raise RuntimeError

        if 'mol' in kwargs:
            self.mol = kwargs['mol']
        elif 'coords' in kwargs:
            if not os.path.exists(kwargs['coords']):
                logger.error("%s specified but doesn't exist\n" %
                             kwargs['coords'])
                raise RuntimeError
            self.mol = Molecule(kwargs['coords'])
        else:
            logger.error(
                'Must provide either a molecule object or coordinate file.\n')
            raise RuntimeError

        # Here we cannot distinguish the .mol2 files linked by the target
        # vs. the .mol2 files to be provided by the force field.
        # But we can assume that these files should exist when this function is called.
        self.mol2 = kwargs.get('mol2')
        if not self.mol2:
            # prepare() iterates over self.mol2 to build the OpenFF topology,
            # so fail early with a clear message instead of crashing later.
            logger.error("Must provide a list of .mol2 files.\n")
            raise RuntimeError
        for fnm in self.mol2:
            if os.path.exists(fnm):
                continue
            # Files belonging to the force field may not be written yet.
            if hasattr(self, 'FF') and fnm in self.FF.fnms:
                continue
            logger.error("%s doesn't exist" % fnm)
            raise RuntimeError

        self.abspdb = os.path.abspath(pdbfnm)
        mpdb = Molecule(pdbfnm)
        # Take the per-atom topology labels from the PDB file.
        for i in ["chain", "atomname", "resid", "resname", "elem"]:
            self.mol.Data[i] = mpdb.Data[i]

        # Store a separate copy of the molecule for reference restraint positions.
        self.ref_mol = deepcopy(self.mol)

    def prepare(self, pbc=False, mmopts=None, **kwargs):
        """
        Prepare the calculation.  Note that we don't create the
        Simulation object yet, because that may depend on MD
        integrator parameters, thermostat, barostat etc.

        This is mostly copied and modified from openmmio.py's OpenMM.prepare(),
        but we are calling ForceField() from the OpenFF toolkit and ignoring
        AMOEBA stuff.

        Parameters
        ----------
        pbc : bool
            Whether to build an orthogonal periodic box for every frame.
        mmopts : dict, optional
            OpenMM system options. The dict is copied, never modified in place.
        **kwargs
            ``offxml`` (used when no ForceBalance force field is attached),
            ``restrain_k``, ``freeze_atoms``; ``nonbonded_cutoff`` is
            ignored with a warning.

        Raises
        ------
        RuntimeError
            If a frame has a non-orthogonal periodic box.
        """
        self.pdb = PDBFile(self.abspdb)

        # Create the OpenFF ForceField object.
        if hasattr(self, 'FF'):
            self.offxml = [self.FF.offxml]
            self.forcefield = self.FF.openff_forcefield
        else:
            self.offxml = listfiles(kwargs.get('offxml'), 'offxml', err=True)
            self.forcefield = OpenFF_ForceField(*self.offxml)

        ## Load mol2 files for smirnoff topology
        openff_mols = []
        for fnm in self.mol2:
            try:
                mol = OffMolecule.from_file(fnm)
            except Exception:
                logger.error("Error when loading %s" % fnm)
                raise
            openff_mols.append(mol)
        self.off_topology = OffTopology.from_openmm(
            self.pdb.topology, unique_molecules=openff_mols)

        # used in create_simulation()
        self.mod = Modeller(self.pdb.topology, self.pdb.positions)

        ## OpenMM options for setting up the System (a private copy).
        self.mmopts = dict(mmopts) if mmopts is not None else {}

        ## Specify frozen atoms and restraint force constant
        if 'restrain_k' in kwargs:
            self.restrain_k = kwargs['restrain_k']
        if 'freeze_atoms' in kwargs:
            self.freeze_atoms = kwargs['freeze_atoms'][:]

        ## Set system options from ForceBalance force field options.
        fftmp = False
        if hasattr(self, 'FF'):
            self.mmopts['rigidWater'] = self.FF.rigid_water
            if not all(os.path.exists(f) for f in self.FF.fnms):
                # If the parameter files don't already exist, create them for the purpose of
                # preparing the engine, but then delete them afterward.
                fftmp = True
                self.FF.make(np.zeros(self.FF.np))

        try:
            ## Set system options from periodic boundary conditions.
            self.pbc = pbc
            ## print warning for 'nonbonded_cutoff' keywords
            if 'nonbonded_cutoff' in kwargs:
                logger.warning(
                    "nonbonded_cutoff keyword ignored because it's set in the offxml file\n"
                )

            ## Generate OpenMM-compatible positions. Virtual sites are not
            ## supported yet, so no extra particles are added here.
            self.xyz_omms = []
            for frame in range(len(self.mol)):
                position = self.mol.xyzs[frame] * angstrom
                if self.pbc:
                    # Obtain the periodic box
                    box = self.mol.boxes[frame]
                    if box.alpha != 90.0 or box.beta != 90.0 or box.gamma != 90.0:
                        logger.error('OpenMM cannot handle nonorthogonal boxes.\n')
                        raise RuntimeError
                    box_omm = np.diag([box.a, box.b, box.c]) * angstrom
                else:
                    box_omm = None
                # Finally append it to list.
                self.xyz_omms.append((position, box_omm))

            ## Build atom lists from the PDB topology.
            Atoms = list(self.pdb.topology.atoms())
            self.AtomLists = defaultdict(list)
            self.AtomLists['Mass'] = [
                a.element.mass.value_in_unit(dalton)
                if a.element is not None else 0 for a in Atoms
            ]
            # Particles lighter than 1 Da (or without an element) are marked 'D'.
            self.AtomLists['ParticleType'] = [
                'A' if m >= 1.0 else 'D' for m in self.AtomLists['Mass']
            ]
            self.AtomLists['ResidueNumber'] = [a.residue.index for a in Atoms]
            self.AtomMask = [a == 'A' for a in self.AtomLists['ParticleType']]
            self.realAtomIdxs = [
                i for i, a in enumerate(self.AtomMask) if a is True
            ]
        finally:
            # Remove the temporary parameter files even if preparation fails.
            if fftmp:
                for f in self.FF.fnms:
                    os.unlink(f)

    def update_simulation(self, **kwargs):
        """
        Create the simulation object, or update the force field
        parameters in the existing simulation object.  This should be
        run when we write a new force field XML file.
        """
        if kwargs:
            self.simkwargs = kwargs

        # Because self.forcefield is being updated in forcebalance.forcefield.FF.make()
        # there is no longer a need to create a new force field object here.
        try:
            self.system = self.forcefield.create_openmm_system(
                self.off_topology)
        except Exception:
            logger.error("Error when creating system for %s" % self.mol2)
            raise
        # Virtual sites are not handled yet, so the simulation is never
        # rebuilt because of changed virtual-site parameters.
        self.nbcharges = np.zeros(self.system.getNumParticles())

        if hasattr(self, 'simulation'):
            UpdateSimulationParameters(self.system, self.simulation)
        else:
            self.create_simulation(**self.simkwargs)

    def optimize(self, shot=0, align=True, crit=1e-4):
        """Delegate to OpenMM.optimize() with virtual sites disabled."""
        return super(SMIRNOFF, self).optimize(shot=shot,
                                              align=align,
                                              crit=crit,
                                              disable_vsite=True)

    def interaction_energy(self, fraga, fragb):
        """
        Calculate the interaction energy for two fragments.
        Because this creates two new objects and requires passing in the mol2 argument,
        the codes are copied and modified from the OpenMM class.

        Returns
        -------
        float
            (E(full) - E(fraga) - E(fragb)) / 4.184.

        Raises
        ------
        RuntimeError
            If this engine itself is named "A" or "B".
        """
        self.update_simulation()

        if self.name == 'A' or self.name == 'B':
            logger.error("Don't name the engine A or B!\n")
            raise RuntimeError

        # Create two subengines (cached on self after the first call).
        # NOTE(review): neither branch passes ``pdb``, which readsrc() requires;
        # presumably the base Engine supplies it from the target - confirm the
        # non-target branch actually works.
        if hasattr(self, 'target'):
            if not hasattr(self, 'A'):
                self.A = SMIRNOFF(name="A",
                                  mol=self.mol.atom_select(fraga),
                                  mol2=self.mol2,
                                  target=self.target)
            if not hasattr(self, 'B'):
                self.B = SMIRNOFF(name="B",
                                  mol=self.mol.atom_select(fragb),
                                  mol2=self.mol2,
                                  target=self.target)
        else:
            if not hasattr(self, 'A'):
                self.A = SMIRNOFF(name="A",
                                  mol=self.mol.atom_select(fraga),
                                  mol2=self.mol2,
                                  platname=self.platname,
                                  precision=self.precision,
                                  offxml=self.offxml,
                                  mmopts=self.mmopts)
            if not hasattr(self, 'B'):
                self.B = SMIRNOFF(name="B",
                                  mol=self.mol.atom_select(fragb),
                                  mol2=self.mol2,
                                  platname=self.platname,
                                  precision=self.precision,
                                  offxml=self.offxml,
                                  mmopts=self.mmopts)

        # Interaction energy needs to be in kcal/mol.
        D = self.energy()
        A = self.A.energy()
        B = self.B.energy()

        return (D - A - B) / 4.184

    def get_smirks_counter(self):
        """Count how many times each SMIRKS pattern is assigned in ``self.off_topology``.

        Returns
        -------
        Counter
            Maps each assigned parameter's ``smirks`` string to its number of
            assignments, summed over all molecules and handlers.
        """
        smirks_counter = Counter()
        molecule_force_list = self.forcefield.label_molecules(
            self.off_topology)
        for mol_forces in molecule_force_list:
            # One dict per handler, e.g. 'Bonds'.
            for force_dict in mol_forces.values():
                for parameters in force_dict.values():
                    # A label may be a list of parameters rather than a single
                    # one; count every element (as the sibling implementation does).
                    if not isinstance(parameters, list):
                        parameters = [parameters]
                    for parameter in parameters:
                        smirks_counter[parameter.smirks] += 1
        return smirks_counter
Example #11
0
class LegacyForceField:
    """Class to hold legacy forcefield for typing and parameter assignment.

    Parameters
    ----------
    forcefield : string
        name and version of the forcefield.

    Methods
    -------
    parametrize()
        Parametrize a molecular system.

    typing()
        Provide legacy typing for a molecular system.

    """
    def __init__(self, forcefield="gaff-1.81"):
        self.forcefield = forcefield
        self._prepare_forcefield()

    @staticmethod
    def _convert_to_off(mol):
        """Return ``mol`` as an OpenFF toolkit ``Molecule``.

        Accepts an ``esp.Graph`` (its ``.mol`` attribute is returned), an
        OpenFF ``Molecule`` (returned unchanged), an RDKit ``Mol`` or an
        OpenEye molecule. OpenEye molecules are detected by type name so
        that OpenEye is not a hard dependency.

        Raises
        ------
        TypeError
            If ``mol`` is none of the supported types (previously ``None``
            was returned silently, deferring the failure to the caller).
        """
        import openff.toolkit

        if isinstance(mol, esp.Graph):
            return mol.mol

        if isinstance(mol, openff.toolkit.topology.molecule.Molecule):
            return mol

        if isinstance(mol, rdkit.Chem.rdchem.Mol):
            return Molecule.from_rdkit(mol)

        if "openeye" in str(type(mol)):  # because we don't want to depend on OE
            return Molecule.from_openeye(mol)

        raise TypeError(
            "Cannot convert object of type %s to an OpenFF Molecule" % type(mol)
        )

    def _prepare_forcefield(self):

        if "gaff" in self.forcefield:
            self._prepare_gaff()

        elif "smirnoff" in self.forcefield:
            # do nothing for now
            self._prepare_smirnoff()

        elif "openff" in self.forcefield:
            self._prepare_openff()

        else:
            raise NotImplementedError

    def _prepare_openff(self):
        """Load ``<self.forcefield>.offxml`` with the OpenFF toolkit into ``self.FF``."""
        from openff.toolkit.typing.engines.smirnoff import ForceField

        offxml_filename = "%s.offxml" % self.forcefield
        self.FF = ForceField(offxml_filename)

    def _prepare_smirnoff(self):
        """Load ``<self.forcefield>.offxml`` (SMIRNOFF format) into ``self.FF``."""
        from openff.toolkit.typing.engines.smirnoff import ForceField

        offxml_filename = "%s.offxml" % self.forcefield
        self.FF = ForceField(offxml_filename)

    def _prepare_gaff(self):
        import os
        import xml.etree.ElementTree as ET

        import openmmforcefields

        # get the openff.toolkits path
        openmmforcefields_path = os.path.dirname(openmmforcefields.__file__)

        # get the xml path
        ffxml_path = (openmmforcefields_path + "/ffxml/amber/gaff/ffxml/" +
                      self.forcefield + ".xml")

        # parse xml
        tree = ET.parse(ffxml_path)
        root = tree.getroot()
        nonbonded = list(root)[-1]
        atom_types = [atom.get("type") for atom in nonbonded.findall("Atom")]

        # remove redundant types
        [atom_types.remove(bad_type) for bad_type in REDUNDANT_TYPES.keys()]

        # compose the translation dictionaries
        str_2_idx = dict(zip(atom_types, range(len(atom_types))))
        idx_2_str = dict(zip(range(len(atom_types)), atom_types))

        # provide mapping for redundant types
        for bad_type, good_type in REDUNDANT_TYPES.items():
            str_2_idx[bad_type] = str_2_idx[good_type]

        # make translation dictionaries attributes of self
        self._str_2_idx = str_2_idx
        self._idx_2_str = idx_2_str

    def _type_gaff(self, g):
        """Type a molecular graph using gaff force fields.

        Runs antechamber (through ``openmmforcefields``'
        ``GAFFTemplateGenerator``) on ``g.mol`` inside a temporary directory.
        The resulting GAFF atom-type strings are translated to indices with
        ``self._str_2_idx`` and stored in ``g.nodes["n1"].data["legacy_typing"]``.

        Parameters
        ----------
        g : esp.Graph
            Graph whose ``mol`` attribute is written to SDF via ``to_file``
            (presumably an OpenFF toolkit molecule - TODO confirm).

        Returns
        -------
        esp.Graph
            The same graph, modified in place.
        """
        import os
        import tempfile

        # assert the forcefield is indeed of gaff family
        assert "gaff" in self.forcefield

        mol = g.mol

        # import template generator
        from openmmforcefields.generators import GAFFTemplateGenerator

        gaff = GAFFTemplateGenerator(molecules=mol, forcefield=self.forcefield)

        prefix = "molecule"
        # TemporaryDirectory removes the antechamber scratch files even when
        # antechamber or the mol2 parsing raises.
        with tempfile.TemporaryDirectory() as tempdir:
            input_sdf_filename = os.path.join(tempdir, prefix + ".sdf")
            gaff_mol2_filename = os.path.join(tempdir, prefix + ".gaff.mol2")
            frcmod_filename = os.path.join(tempdir, prefix + ".frcmod")

            # write sdf for input
            mol.to_file(input_sdf_filename, file_format="sdf")

            # run antechamber
            gaff._run_antechamber(
                molecule_filename=input_sdf_filename,
                input_format="mdl",
                gaff_mol2_filename=gaff_mol2_filename,
                frcmod_filename=frcmod_filename,
            )

            # sets ``gaff_type`` on each atom of ``mol``
            gaff._read_gaff_atom_types_from_mol2(gaff_mol2_filename, mol)
            gaff_types = [atom.gaff_type for atom in mol.atoms]

        # put types into graph object
        g.nodes["n1"].data["legacy_typing"] = torch.tensor(
            [self._str_2_idx[atom] for atom in gaff_types])

        return g

    def _parametrize_gaff(self, g, n_max_phases=6):
        """Copy reference GAFF parameters from an OpenMM System onto ``g``.

        A ``SystemGenerator`` for ``self.forcefield`` builds a System for
        ``g.mol``. Its parameters are written, converted to ``esp.units``, as:

        * ``n2``: ``eq_ref`` / ``k_ref`` (each bond in both directions);
        * ``n3``: ``eq_ref`` / ``k_ref`` (each angle in both directions);
        * ``n4``: ``k_ref`` / ``periodicity_ref`` / ``phases_ref`` with up to
          ``n_max_phases`` Fourier terms per proper torsion.

        Improper torsions are not assigned.

        Parameters
        ----------
        g : esp.Graph
        n_max_phases : int
            Fourier-term slots per torsion; terms beyond that are dropped.

        Returns
        -------
        esp.Graph
            The same graph, modified in place.
        """
        from openmmforcefields.generators import SystemGenerator

        # define a system generator
        system_generator = SystemGenerator(
            small_molecule_forcefield=self.forcefield, )

        mol = g.mol
        # create system (named ``system`` so the stdlib ``sys`` is not shadowed)
        system = system_generator.create_system(
            topology=mol.to_topology().to_openmm(),
            molecules=mol,
        )

        def _index_lookup(ntype):
            # Map each node's atom-index tuple to its row in g.nodes[ntype].
            return {
                tuple(idxs.detach().numpy()): position
                for position, idxs in enumerate(g.nodes[ntype].data["idxs"])
            }

        bond_lookup = _index_lookup("n2")
        angle_lookup = _index_lookup("n3")
        torsion_lookup = _index_lookup("n4")

        # A torsion slot counts as free while its k is still zero.
        n_torsions = g.heterograph.number_of_nodes("n4")
        torsion_phases = torch.zeros(n_torsions, n_max_phases)
        torsion_periodicities = torch.zeros(n_torsions, n_max_phases)
        torsion_ks = torch.zeros(n_torsions, n_max_phases)

        for force in system.getForces():
            name = force.__class__.__name__
            if "HarmonicBondForce" in name:
                n_bonds = force.getNumBonds()
                # "n2" holds every bond twice, once per direction.
                assert (n_bonds * 2 == g.heterograph.number_of_nodes("n2"))

                eq_ref = torch.zeros(n_bonds * 2, 1)
                k_ref = torch.zeros(n_bonds * 2, 1)
                for idx in range(n_bonds):
                    idx0, idx1, eq, k = force.getBondParameters(idx)
                    eq_value = eq.value_in_unit(esp.units.DISTANCE_UNIT, )
                    k_value = k.value_in_unit(esp.units.FORCE_CONSTANT_UNIT, )
                    for key in ((idx0, idx1), (idx1, idx0)):
                        position = bond_lookup[key]
                        eq_ref[position] = eq_value
                        k_ref[position] = k_value

                g.nodes["n2"].data["eq_ref"] = eq_ref
                g.nodes["n2"].data["k_ref"] = k_ref

            elif "HarmonicAngleForce" in name:
                n_angles = force.getNumAngles()
                # "n3" holds every angle twice, once per direction.
                assert (n_angles * 2 == g.heterograph.number_of_nodes("n3"))

                eq_ref = torch.zeros(n_angles * 2, 1)
                k_ref = torch.zeros(n_angles * 2, 1)
                for idx in range(n_angles):
                    idx0, idx1, idx2, eq, k = force.getAngleParameters(idx)
                    eq_value = eq.value_in_unit(esp.units.ANGLE_UNIT, )
                    k_value = k.value_in_unit(
                        esp.units.ANGLE_FORCE_CONSTANT_UNIT, )
                    for key in ((idx0, idx1, idx2), (idx2, idx1, idx0)):
                        position = angle_lookup[key]
                        eq_ref[position] = eq_value
                        k_ref[position] = k_value

                g.nodes["n3"].data["eq_ref"] = eq_ref
                g.nodes["n3"].data["k_ref"] = k_ref

            elif "PeriodicTorsionForce" in name:
                for idx in range(force.getNumTorsions()):
                    (
                        idx0,
                        idx1,
                        idx2,
                        idx3,
                        periodicity,
                        phase,
                        k,
                    ) = force.getTorsionParameters(idx)

                    forward = (idx0, idx1, idx2, idx3)
                    if forward not in torsion_lookup:
                        continue

                    position = torsion_lookup[forward]
                    sub_idx = next(
                        (s for s in range(n_max_phases)
                         if torsion_ks[position, s] == 0),
                        None,
                    )
                    if sub_idx is None:
                        # All slots taken: this extra term is dropped.
                        continue

                    # k is halved and written to both the forward and the
                    # reversed torsion row (presumably so the two copies add
                    # up to the OpenMM term).
                    k_value = 0.5 * k.value_in_unit(esp.units.ENERGY_UNIT)
                    phase_value = phase.value_in_unit(esp.units.ANGLE_UNIT)
                    for key in (forward, (idx3, idx2, idx1, idx0)):
                        row = torsion_lookup[key]
                        torsion_ks[row, sub_idx] = k_value
                        torsion_phases[row, sub_idx] = phase_value
                        torsion_periodicities[row, sub_idx] = periodicity

        # Assign the torsion tables once, after every force has been read.
        g.heterograph.apply_nodes(
            lambda nodes: {
                "k_ref": torsion_ks,
                "periodicity_ref": torsion_periodicities,
                "phases_ref": torsion_phases,
            },
            ntype="n4",
        )

        return g

    def _parametrize_smirnoff(self, g):
        """Copy SMIRNOFF/OpenFF reference parameters onto the nodes of ``g``.

        ``self.FF`` labels ``g.mol``. The labelled parameters are written,
        converted to ``esp.units``, as:

        * ``n2``: ``k_ref`` (twice the labelled ``k``) and ``eq_ref`` (``length``);
        * ``n3``: ``k_ref`` (twice the labelled ``k``) and ``eq_ref`` (``angle``);
        * ``n1``: ``epsilon_ref`` and ``sigma_ref`` (taken from ``rmin_half``);
        * ``n4`` / ``n4_improper``: ``k_ref``, ``periodicity_ref`` and
          ``phases_ref`` with six Fourier-term slots per torsion (zeros where
          nothing is labelled).

        Returns
        -------
        esp.Graph
            The same graph, modified in place.
        """
        forces = self.FF.label_molecules(g.mol.to_topology())[0]

        def _labelled(node, handler):
            # Labelled parameter of every row of node.data["idxs"], in row order.
            idxs = node.data["idxs"]
            return [
                forces[handler][tuple(idxs[idx].numpy())]
                for idx in range(idxs.shape[0])
            ]

        def _column(values):
            # Float tensor of shape (len(values), 1).
            return torch.Tensor(values)[:, None]

        def apply_bond(node):
            params = _labelled(node, "Bonds")
            return {
                "k_ref": 2.0 * _column([
                    p.k.value_in_unit(esp.units.FORCE_CONSTANT_UNIT)
                    for p in params
                ]),
                "eq_ref": _column([
                    p.length.value_in_unit(esp.units.DISTANCE_UNIT)
                    for p in params
                ]),
            }

        def apply_angle(node):
            params = _labelled(node, "Angles")
            return {
                # OpenFF records 1/2k as param
                "k_ref": 2.0 * _column([
                    p.k.value_in_unit(esp.units.ANGLE_FORCE_CONSTANT_UNIT)
                    for p in params
                ]),
                "eq_ref": _column([
                    p.angle.value_in_unit(esp.units.ANGLE_UNIT)
                    for p in params
                ]),
            }

        def apply_vdw(node):
            params = [
                forces["vdW"][(idx, )]
                for idx in range(g.heterograph.number_of_nodes("n1"))
            ]
            return {
                "epsilon_ref": _column([
                    p.epsilon.value_in_unit(esp.units.ENERGY_UNIT)
                    for p in params
                ]),
                # NOTE(review): sigma_ref is filled from rmin_half, not sigma;
                # confirm this matches what downstream code expects.
                "sigma_ref": _column([
                    p.rmin_half.value_in_unit(esp.units.DISTANCE_UNIT)
                    for p in params
                ]),
            }

        def _torsion_applier(handler, ntype):
            # Build the apply_nodes callback shared by proper and improper torsions.
            def apply(node, n_max_phases=6):
                n_nodes = g.heterograph.number_of_nodes(ntype)
                phases = torch.zeros(n_nodes, n_max_phases)
                periodicity = torch.zeros(n_nodes, n_max_phases)
                k = torch.zeros(n_nodes, n_max_phases)

                force = forces[handler]

                for idx in range(n_nodes):
                    idxs = tuple(node.data["idxs"][idx].numpy())
                    if idxs not in force:
                        continue
                    _force = force[idxs]
                    for sub_idx in range(len(_force.periodicity)):
                        # NOTE(review): terms are read via indexed names
                        # k0, k1, ... starting at 0; confirm how the OpenFF
                        # toolkit resolves "k0" (its indexed names may be 1-based).
                        if hasattr(_force, "k%s" % sub_idx):
                            k[idx, sub_idx] = getattr(
                                _force, "k%s" % sub_idx).value_in_unit(
                                    esp.units.ENERGY_UNIT)

                            phases[idx, sub_idx] = getattr(
                                _force, "phase%s" % sub_idx).value_in_unit(
                                    esp.units.ANGLE_UNIT)

                            periodicity[idx, sub_idx] = getattr(
                                _force, "periodicity%s" % sub_idx)

                return {
                    "k_ref": k,
                    "periodicity_ref": periodicity,
                    "phases_ref": phases,
                }

            return apply

        g.heterograph.apply_nodes(apply_bond, ntype="n2")
        g.heterograph.apply_nodes(apply_angle, ntype="n3")
        g.heterograph.apply_nodes(apply_vdw, ntype="n1")
        g.heterograph.apply_nodes(
            _torsion_applier("ProperTorsions", "n4"), ntype="n4")
        g.heterograph.apply_nodes(
            _torsion_applier("ImproperTorsions", "n4_improper"),
            ntype="n4_improper")

        return g

    def baseline_energy(self, g, suffix=None):
        """Evaluate ``self.forcefield`` energies for every conformation of ``g``.

        Builds an OpenMM System with ``openmmforcefields``' ``SystemGenerator``
        and evaluates the potential energy at each set of coordinates in
        ``g.nodes["n1"].data["xyz"]``.

        Parameters
        ----------
        g : esp.Graph
        suffix : str, optional
            Key suffix; defaults to ``"_" + self.forcefield``.

        Returns
        -------
        esp.Graph
            ``g`` with ``g.nodes["g"].data["u" + suffix]`` set to a tensor of
            shape ``(1, n_conformations)`` in ``esp.units.ENERGY_UNIT``.
        """
        if suffix is None:
            suffix = "_" + self.forcefield

        from openmmforcefields.generators import SystemGenerator

        # define a system generator
        system_generator = SystemGenerator(
            small_molecule_forcefield=self.forcefield, )

        mol = g.mol
        # Convert once and use the same OpenMM topology for the System and
        # the Simulation.
        topology = mol.to_topology().to_openmm()
        system = system_generator.create_system(
            topology=topology,
            molecules=mol,
        )

        integrator = openmm.LangevinIntegrator(TEMPERATURE, COLLISION_RATE,
                                               STEP_SIZE)

        # create simulation
        simulation = Simulation(topology=topology,
                                system=system,
                                integrator=integrator)

        # Assumes xyz is (n_atoms, n_conformations, 3); the transpose makes
        # iteration yield one (n_atoms, 3) array per conformation - TODO confirm.
        xs = (Quantity(
            g.nodes["n1"].data["xyz"].detach().numpy(),
            esp.units.DISTANCE_UNIT,
        ).value_in_unit(unit.nanometer).transpose((1, 0, 2)))

        us = []
        for x in xs:
            simulation.context.setPositions(x)
            state = simulation.context.getState(getEnergy=True)
            us.append(
                state.getPotentialEnergy().value_in_unit(esp.units.ENERGY_UNIT))

        g.nodes["g"].data["u%s" % suffix] = torch.tensor(us)[None, :]

        return g

    def _multi_typing_smirnoff(self, g):
        # mol = self._convert_to_off(mol)

        forces = self.FF.label_molecules(g.mol.to_topology())[0]

        g.heterograph.apply_nodes(
            lambda node: {
                "legacy_typing":
                torch.Tensor([
                    int(forces["Bonds"][tuple(node.data["idxs"][idx].numpy())].
                        id[1:]) for idx in range(node.data["idxs"].shape[0])
                ]).long()
            },
            ntype="n2",
        )

        g.heterograph.apply_nodes(
            lambda node: {
                "legacy_typing":
                torch.Tensor([
                    int(forces["Angles"][tuple(node.data["idxs"][idx].numpy())]
                        .id[1:]) for idx in range(node.data["idxs"].shape[0])
                ]).long()
            },
            ntype="n3",
        )

        g.heterograph.apply_nodes(
            lambda node: {
                "legacy_typing":
                torch.Tensor([
                    int(forces["vdW"][(idx, )].id[1:])
                    for idx in range(g.heterograph.number_of_nodes("n1"))
                ]).long()
            },
            ntype="n1",
        )

        return g

    def parametrize(self, g):
        """Parametrize a molecular graph."""
        if "smirnoff" in self.forcefield or "openff" in self.forcefield:
            return self._parametrize_smirnoff(g)

        elif "gaff" in self.forcefield:
            return self._parametrize_gaff(g)

        else:
            raise NotImplementedError

    def typing(self, g):
        """Type a molecular graph."""
        if "gaff" in self.forcefield:
            return self._type_gaff(g)

        else:
            raise NotImplementedError

    def multi_typing(self, g):
        """ Type a molecular graph for hetero nodes. """
        if "smirnoff" in self.forcefield:
            return self._multi_typing_smirnoff(g)

        else:
            raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.typing(*args, **kwargs)