Пример #1
0
    def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
        """Parameterize one nonbonded potential spanning mol_a and mol_b.

        Atoms of ``self.mol_a`` occupy indices ``[0, NA)`` and atoms of
        ``self.mol_b`` occupy ``[NA, NA + NB)``. Intramolecular exclusions come
        from ``nonbonded.generate_exclusion_idxs`` using the module-level
        ``_SCALE_12`` / ``_SCALE_13`` / ``_SCALE_14`` factors (the same factor is
        used for the charge and LJ columns). In addition, every pair
        (i in mol_a, j in mol_b) is added as an exclusion with scale factor 1.0
        in both columns, i.e. the cross-molecule interaction is excluded.

        Parameters
        ----------
        ff_q_params:
            Forcefield charge parameters, passed to
            ``self.ff.q_handle.partial_parameterize``.
        ff_lj_params:
            Forcefield LJ parameters, passed to
            ``self.ff.lj_handle.partial_parameterize``.

        Returns
        -------
        tuple
            ``(qlj_params, potentials.Nonbonded)``; ``qlj_params`` has one row
            per atom (mol_a atoms first): column 0 is the charge, columns 1-2
            are the two LJ parameters.
        """
        q_handle = self.ff.q_handle
        lj_handle = self.ff.lj_handle

        q_params_a = q_handle.partial_parameterize(ff_q_params, self.mol_a)
        q_params_b = q_handle.partial_parameterize(ff_q_params, self.mol_b)
        lj_params_a = lj_handle.partial_parameterize(ff_lj_params, self.mol_a)
        lj_params_b = lj_handle.partial_parameterize(ff_lj_params, self.mol_b)

        q_params = jnp.concatenate([q_params_a, q_params_b])
        lj_params = jnp.concatenate([lj_params_a, lj_params_b])

        exclusion_idxs_a, scale_factors_a = nonbonded.generate_exclusion_idxs(
            self.mol_a,
            scale12=_SCALE_12,
            scale13=_SCALE_13,
            scale14=_SCALE_14)

        exclusion_idxs_b, scale_factors_b = nonbonded.generate_exclusion_idxs(
            self.mol_b,
            scale12=_SCALE_12,
            scale13=_SCALE_13,
            scale14=_SCALE_14)

        NA = self.mol_a.GetNumAtoms()
        NB = self.mol_b.GetNumAtoms()

        # All cross pairs [i, j + NA], ordered i-major. Reshaping to (-1, 2)
        # keeps the array 2D even when NA or NB is zero; a list built in a loop
        # would become a 1D empty array and break the concatenation below.
        a_idxs, b_idxs = np.meshgrid(
            np.arange(NA), np.arange(NB) + NA, indexing="ij")
        mutual_exclusions = np.stack([a_idxs, b_idxs], axis=-1).reshape(-1, 2)
        mutual_scale_factors = np.ones((NA * NB, 2), dtype=np.float64)

        combined_exclusion_idxs = np.concatenate(
            [exclusion_idxs_a, exclusion_idxs_b + NA,
             mutual_exclusions]).astype(np.int32)

        # duplicate each per-exclusion factor into (charge, lj) columns
        combined_scale_factors = np.concatenate([
            np.stack([scale_factors_a, scale_factors_a], axis=1),
            np.stack([scale_factors_b, scale_factors_b], axis=1),
            mutual_scale_factors,
        ]).astype(np.float64)

        combined_lambda_plane_idxs = None
        combined_lambda_offset_idxs = None

        beta = _BETA
        cutoff = _CUTOFF  # TODO: solve for this analytically later

        qlj_params = jnp.concatenate(
            [jnp.reshape(q_params, (-1, 1)),
             jnp.reshape(lj_params, (-1, 2))],
            axis=1)

        return qlj_params, potentials.Nonbonded(
            combined_exclusion_idxs,
            combined_scale_factors,
            combined_lambda_plane_idxs,
            combined_lambda_offset_idxs,
            beta,
            cutoff,
        )
Пример #2
0
def combine_potentials(ff_handlers, guest_mol, host_system, precision):
    """
    Combine a host system and a guest molecule into one alchemical system.

    The host's non-nonbonded potentials are kept as-is. The guest's bonded
    terms are appended with their atom indices shifted past the host atoms.
    The host and guest nonbonded terms are merged into a single
    ``potentials.Nonbonded`` in which every guest atom has lambda offset
    index 1, so that the guest can be alchemically decoupled.

    Parameters
    ----------

    ff_handlers: list of forcefield handlers
        Small molecule forcefield handlers. Must include an ``AM1CCCHandler``
        and a ``LennardJonesHandler``; unrecognized handlers are skipped with
        a printed warning.

    guest_mol: Chem.ROMol
        RDKit molecule

    host_system: openmm.System
        Host system to be deserialized. Must yield exactly one
        ``potentials.Nonbonded`` term.

    precision: np.float32 or np.float64
        Numerical precision of the functional form

    Returns
    -------
    tuple
        ``(combined_potentials, combined_masses, combined_vjp_fns)``: a list of
        bound lib.potentials objects, the host masses followed by the guest
        masses, and, per potential, a list of ``(handler, vjp_fn)`` pairs that
        map adjoints back into the forcefield (empty for host-only terms).

    Raises
    ------
    ValueError
        If the host has no nonbonded term, or the charge or LJ handler is
        missing from ``ff_handlers``.
    AssertionError
        If the host has more than one nonbonded term.

    """

    host_potentials, host_masses = openmm_deserializer.deserialize_system(
        host_system, precision, cutoff=1.0)

    host_nb_bp = None

    combined_potentials = []
    combined_vjp_fns = []

    for bp in host_potentials:
        if isinstance(bp, potentials.Nonbonded):
            # (ytz): hack to ensure we only have one nonbonded term
            assert host_nb_bp is None
            host_nb_bp = bp
        else:
            combined_potentials.append(bp)
            combined_vjp_fns.append([])

    if host_nb_bp is None:
        raise ValueError("host_system has no nonbonded term")

    guest_masses = np.array([a.GetMass() for a in guest_mol.GetAtoms()],
                            dtype=np.float64)

    num_host_atoms = len(host_masses)

    combined_masses = np.concatenate([host_masses, guest_masses])

    # guest bonded handler type -> potential class consuming its output
    bonded_potential_types = (
        (bonded.HarmonicBondHandler, potentials.HarmonicBond),
        (bonded.HarmonicAngleHandler, potentials.HarmonicAngle),
        (bonded.ProperTorsionHandler, potentials.PeriodicTorsion),
        (bonded.ImproperTorsionHandler, potentials.PeriodicTorsion),
    )

    charge_handle = None
    lj_handle = None

    for handle in ff_handlers:
        results = handle.parameterize(guest_mol)
        potential_type = next(
            (pot_type for handler_type, pot_type in bonded_potential_types
             if isinstance(handle, handler_type)),
            None,
        )
        if potential_type is not None:
            idxs, (params, vjp_fn) = results
            # shift guest indices past the host atoms; not in-place, so the
            # array returned by the handler is left unmodified
            idxs = idxs + num_host_atoms
            combined_potentials.append(
                potential_type(idxs, precision=precision).bind(params))
            combined_vjp_fns.append([(handle, vjp_fn)])
        elif isinstance(handle, nonbonded.AM1CCCHandler):
            charge_handle = handle
            guest_charge_params, guest_charge_vjp_fn = results
        elif isinstance(handle, nonbonded.LennardJonesHandler):
            lj_handle = handle
            guest_lj_params, guest_lj_vjp_fn = results
        else:
            print("Warning: skipping handler", handle)

    if charge_handle is None or lj_handle is None:
        raise ValueError(
            "ff_handlers must include an AM1CCCHandler and a "
            "LennardJonesHandler")

    # process nonbonded terms
    combined_nb_params, (charge_vjp_fn, lj_vjp_fn) = nonbonded_vjps(
        guest_charge_params, guest_charge_vjp_fn, guest_lj_params,
        guest_lj_vjp_fn, host_nb_bp.params)

    # these vjp_fns take in adjoints of combined_params and return derivatives
    # appropriate to the underlying handler
    combined_vjp_fns.append([(charge_handle, charge_vjp_fn),
                             (lj_handle, lj_vjp_fn)])

    # TODO: change scale 14 for electrostatics
    guest_exclusion_idxs, guest_scale_factors = nonbonded.generate_exclusion_idxs(
        guest_mol, scale12=1.0, scale13=1.0, scale14=0.5)

    # allow the ligand to be alchemically decoupled
    # a value of one indicates that we allow the atom to be adjusted by the lambda value
    guest_lambda_offset_idxs = np.ones(len(guest_masses), dtype=np.int32)

    # use same scale factors until we modify 1-4s for electrostatics
    guest_scale_factors = np.stack([guest_scale_factors, guest_scale_factors],
                                   axis=1)

    combined_lambda_offset_idxs = np.concatenate(
        [host_nb_bp.get_lambda_offset_idxs(), guest_lambda_offset_idxs])
    combined_exclusion_idxs = np.concatenate([
        host_nb_bp.get_exclusion_idxs(), guest_exclusion_idxs + num_host_atoms
    ])
    combined_scales = np.concatenate(
        [host_nb_bp.get_scale_factors(), guest_scale_factors])
    combined_beta = 2.0

    combined_cutoff = 1.0  # nonbonded cutoff

    combined_potentials.append(
        potentials.Nonbonded(
            combined_exclusion_idxs,
            combined_scales,
            combined_lambda_offset_idxs,
            combined_beta,
            combined_cutoff,
            precision=precision,
        ).bind(combined_nb_params))

    return combined_potentials, combined_masses, combined_vjp_fns
Пример #3
0
    def parameterize_nonbonded(self, ff_q_params, ff_lj_params):
        """Build the combined host-guest nonbonded potential.

        The guest topology is parameterized first. Its exclusions (shifted by
        ``self.num_host_atoms``), scale factors and lambda offset/plane indices
        are appended after those of ``self.host_nonbonded``.

        If the guest returns a ``potentials.NonbondedInterpolated``, the guest
        parameters hold two stacked sets (source rows, then destination rows)
        and the result is a ``NonbondedInterpolated`` whose params are
        ``[host, guest_src, host, guest_dst]``. Otherwise the result is a plain
        ``potentials.Nonbonded`` with params ``[host, guest]``.

        Returns
        -------
        tuple
            ``(params, potential)``.
        """
        n_guest = self.guest_topology.get_num_atoms()
        guest_qlj, guest_p = self.guest_topology.parameterize_nonbonded(
            ff_q_params, ff_lj_params)

        # parameter interpolation is signalled by the guest's potential type
        is_interpolated = isinstance(guest_p, potentials.NonbondedInterpolated)
        expected_rows = n_guest * 2 if is_interpolated else n_guest
        assert guest_qlj.shape[0] == expected_rows

        assert guest_qlj.shape[1] == 3
        host = self.host_nonbonded
        assert guest_p.get_beta() == host.get_beta()
        assert guest_p.get_cutoff() == host.get_cutoff()

        hg_exclusion_idxs = np.concatenate([
            host.get_exclusion_idxs(),
            guest_p.get_exclusion_idxs() + self.num_host_atoms,
        ])
        hg_scale_factors = np.concatenate(
            [host.get_scale_factors(), guest_p.get_scale_factors()])
        hg_lambda_offset_idxs = np.concatenate(
            [host.get_lambda_offset_idxs(), guest_p.get_lambda_offset_idxs()])
        hg_lambda_plane_idxs = np.concatenate(
            [host.get_lambda_plane_idxs(), guest_p.get_lambda_plane_idxs()])

        if is_interpolated:
            # source and destination sets, each host params then guest params
            src = jnp.concatenate([host.params, guest_qlj[:n_guest]])
            dst = jnp.concatenate([host.params, guest_qlj[n_guest:]])
            hg_nb_params = jnp.concatenate([src, dst])
            potential_cls = potentials.NonbondedInterpolated
        else:
            hg_nb_params = jnp.concatenate([host.params, guest_qlj])
            potential_cls = potentials.Nonbonded

        return hg_nb_params, potential_cls(
            hg_exclusion_idxs,
            hg_scale_factors,
            hg_lambda_plane_idxs,
            hg_lambda_offset_idxs,
            guest_p.get_beta(),
            guest_p.get_cutoff(),
        )
Пример #4
0
def prepare_water_system(x, lambda_plane_idxs, lambda_offset_idxs, p_scale,
                         cutoff):
    """Build a random water-like nonbonded test system and its reference energy.

    The rows of ``x`` are grouped into consecutive (O, H1, H2) triples. Charges,
    sigmas and epsilons are drawn at random. For each triple, the O-H1 and O-H2
    pairs are excluded with scale factors (1.0, 1.0), and the H1-H2 pair with
    two random scale factors (charge, LJ).

    Parameters
    ----------
    x: np.ndarray
        2D coordinate array whose number of rows is a multiple of 3.
    lambda_plane_idxs, lambda_offset_idxs:
        Passed unchanged to both the test potential and the reference energy.
    p_scale:
        Not used by this function.
    cutoff: float
        Nonbonded cutoff.

    Returns
    -------
    tuple
        ``(params, ref_total_energy, test_potential)``. ``params`` is an
        ``(N, 3)`` array of [charge, sigma / 2, sqrt(eps)].
        ``ref_total_energy`` is ``nonbonded.nonbonded_v3`` with the rescale
        masks and settings bound. ``test_potential`` is the matching
        ``potentials.Nonbonded``.
    """
    assert x.ndim == 2
    num_atoms = x.shape[0]
    assert num_atoms % 3 == 0

    # draw order (charges, sigmas, epsilons) matters for reproducibility
    charges = (np.random.rand(num_atoms).astype(np.float64) -
               0.5) * np.sqrt(ONE_4PI_EPS0)
    sigmas = np.random.rand(num_atoms).astype(np.float64) / 5.0
    epsilons = np.random.rand(num_atoms).astype(np.float64)
    params = np.stack([charges, sigmas / 2, np.sqrt(epsilons)], axis=1)

    exclusion_idxs = []
    scales = []
    for oxygen in range(0, num_atoms, 3):
        hydrogen1, hydrogen2 = oxygen + 1, oxygen + 2
        exclusion_idxs.extend([
            [oxygen, hydrogen1],  # 1-2
            [oxygen, hydrogen2],  # 1-2
            [hydrogen1, hydrogen2],  # 1-3
        ])
        scales.extend([
            [1.0, 1.0],
            [1.0, 1.0],
            [np.random.rand(), np.random.rand()],
        ])

    scales = np.array(scales, dtype=np.float64)
    exclusion_idxs = np.array(exclusion_idxs, dtype=np.int32)

    beta = 2.0

    test_potential = potentials.Nonbonded(exclusion_idxs, scales,
                                          lambda_plane_idxs,
                                          lambda_offset_idxs, beta, cutoff)

    # masks hold 1 - scale on every excluded pair (both orderings), 1 elsewhere
    src, dst = exclusion_idxs[:, 0], exclusion_idxs[:, 1]

    charge_rescale_mask = np.ones((num_atoms, num_atoms))
    charge_rescale_mask[src, dst] = 1 - scales[:, 0]
    charge_rescale_mask[dst, src] = 1 - scales[:, 0]

    lj_rescale_mask = np.ones((num_atoms, num_atoms))
    lj_rescale_mask[src, dst] = 1 - scales[:, 1]
    lj_rescale_mask[dst, src] = 1 - scales[:, 1]

    ref_total_energy = functools.partial(
        nonbonded.nonbonded_v3,
        charge_rescale_mask=charge_rescale_mask,
        lj_rescale_mask=lj_rescale_mask,
        beta=beta,
        cutoff=cutoff,
        lambda_plane_idxs=lambda_plane_idxs,
        lambda_offset_idxs=lambda_offset_idxs,
        runtime_validate=False,
    )

    return params, ref_total_energy, test_potential
Пример #5
0
    def from_rdkit(cls, mol, ff_handlers):
        """
        Initialize a system from an RDKit ROMol.

        Parameters
        ----------
        mol: Chem.ROMol
            RDKit ROMol. Should have graphical hydrogens in the topology.

        ff_handlers: list of forcefield handlers.
            openforcefield small molecule handlers. Must include an
            ``AM1CCCHandler`` and a ``LennardJonesHandler``; handlers of any
            other unrecognized type are skipped with a printed warning.

        Returns
        -------
        cls
            ``cls(masses, bound_potentials)``. The bonded potentials come first,
            in handler order, followed by a single ``potentials.Nonbonded``
            term. All lambda plane/offset indices of that term are zero.

        Raises
        ------
        ValueError
            If no charge handler or no Lennard-Jones handler is provided.

        """
        masses = np.array([a.GetMass() for a in mol.GetAtoms()], dtype=np.float64)

        # bonded handler type -> potential class consuming its (params, idxs)
        bonded_potential_types = (
            (bonded.HarmonicBondHandler, potentials.HarmonicBond),
            (bonded.HarmonicAngleHandler, potentials.HarmonicAngle),
            (bonded.ProperTorsionHandler, potentials.PeriodicTorsion),
            (bonded.ImproperTorsionHandler, potentials.PeriodicTorsion),
        )

        bound_potentials = []
        charge_params = None
        lj_params = None

        for handle in ff_handlers:
            results = handle.parameterize(mol)
            potential_type = next(
                (pot_type for handler_type, pot_type in bonded_potential_types
                 if isinstance(handle, handler_type)),
                None,
            )
            if potential_type is not None:
                params, idxs = results
                bound_potentials.append(potential_type(idxs).bind(params))
            elif isinstance(handle, nonbonded.AM1CCCHandler):
                charge_params = results
            elif isinstance(handle, nonbonded.LennardJonesHandler):
                lj_params = results
            else:
                print("WARNING: skipping handler", handle)

        if charge_params is None or lj_params is None:
            raise ValueError(
                "ff_handlers must include an AM1CCCHandler and a "
                "LennardJonesHandler")

        lambda_plane_idxs = np.zeros(len(masses), dtype=np.int32)
        lambda_offset_idxs = np.zeros(len(masses), dtype=np.int32)

        exclusion_idxs, scale_factors = nonbonded.generate_exclusion_idxs(
            mol,
            scale12=1.0,
            scale13=1.0,
            scale14=0.5
        )

        # use same scale factors until we modify 1-4s for electrostatics
        scale_factors = np.stack([scale_factors, scale_factors], axis=1)

        # (ytz) fix this later to not be so hard coded
        alpha = 2.0  # same as ewald alpha; passed as the Nonbonded beta argument
        cutoff = 1.0  # nonbonded cutoff

        # one row per atom: [charge, lj_0, lj_1]
        qlj_params = jnp.concatenate([
            jnp.reshape(charge_params, (-1, 1)),
            jnp.reshape(lj_params, (-1, 2))
        ], axis=1)

        bound_potentials.append(potentials.Nonbonded(
            exclusion_idxs,
            scale_factors,
            lambda_plane_idxs,
            lambda_offset_idxs,
            alpha,
            cutoff).bind(qlj_params))

        return cls(masses, bound_potentials)
Пример #6
0
    def combine(self, other):
        """
        Combine two recipes together. self will keep its original indexing,
        while other will be incremented. This method automatically increments
        indices in the potential functions accordingly. For nonbonded terms, the recipe
        does a straight forward concatenation of the lambda idxs.

        Parameters
        ----------
        other: Recipe
            the right hand side recipe to combine with

        Returns
        -------
        Recipe
            combined recipe

        Raises
        ------
        ValueError
            If self has more than one Nonbonded potential, or only one of
            self / other has a Nonbonded potential (the terms cannot be merged,
            and previously self's term was silently dropped).
        Exception
            If other contains a potential of an unsupported functional form.

        """
        self_num_atoms = len(self.masses)
        combined_masses = np.concatenate([self.masses, other.masses])
        combined_bound_potentials = []

        # self's nonbonded term is held back and merged with other's below
        self_nb = None
        for bp in self.bound_potentials:
            if isinstance(bp, potentials.Nonbonded):
                if self_nb is not None:
                    raise ValueError("self has more than one Nonbonded potential")
                self_nb = bp
            else:
                combined_bound_potentials.append(bp)

        merged_nonbonded = False

        for full_obp in other.bound_potentials:
            # always deepcopy to prevent modifying original copy
            full_obp = copy.deepcopy(full_obp)

            # if this is a lambda potential we replace only the .u_fn part
            if isinstance(full_obp, potentials.LambdaPotential):
                full_obp.set_N(full_obp.get_N() + self_num_atoms)
                obp = full_obp.get_u_fn()
            else:
                obp = full_obp

            if isinstance(obp, (potentials.HarmonicBond, potentials.CoreRestraint,
                                potentials.HarmonicAngle, potentials.PeriodicTorsion)):
                idxs = obp.get_idxs()
                idxs += self_num_atoms  # modify inplace (obp is a deep copy)
            elif isinstance(obp, potentials.CentroidRestraint):
                a_idxs = obp.get_a_idxs()
                a_idxs += self_num_atoms  # modify inplace
                b_idxs = obp.get_b_idxs()
                b_idxs += self_num_atoms  # modify inplace
                # pad masses with zeros for self's atoms
                obp.set_masses(np.concatenate([np.zeros(self_num_atoms), obp.get_masses()]))
            elif isinstance(obp, potentials.Nonbonded):
                if self_nb is None:
                    raise ValueError(
                        "other has a Nonbonded potential but self does not")

                self_nb_cutoff = self_nb.get_cutoff()
                self_nb_beta = self_nb.get_beta()
                assert self_nb_cutoff == obp.get_cutoff()
                assert self_nb_beta == obp.get_beta()

                combined_nb_params = jnp.concatenate([self_nb.params, obp.params])
                combined_exclusion_idxs = np.concatenate(
                    [self_nb.get_exclusion_idxs(), obp.get_exclusion_idxs() + self_num_atoms])
                combined_scale_factors = np.concatenate(
                    [self_nb.get_scale_factors(), obp.get_scale_factors()])
                combined_lambda_offset_idxs = np.concatenate(
                    [self_nb.get_lambda_offset_idxs(), obp.get_lambda_offset_idxs()])
                combined_lambda_plane_idxs = np.concatenate(
                    [self_nb.get_lambda_plane_idxs(), obp.get_lambda_plane_idxs()])

                # NOTE(review): if this Nonbonded was wrapped in a LambdaPotential,
                # the merged obp is not written back into full_obp, and the
                # wrapper (with its un-merged u_fn) is what gets appended below.
                obp = potentials.Nonbonded(
                    combined_exclusion_idxs,
                    combined_scale_factors,
                    combined_lambda_plane_idxs,
                    combined_lambda_offset_idxs,
                    self_nb_beta,
                    self_nb_cutoff).bind(combined_nb_params)
                merged_nonbonded = True
            else:
                raise Exception("Unknown functional form")

            if isinstance(full_obp, potentials.LambdaPotential):
                combined_bound_potentials.append(full_obp)
            else:
                combined_bound_potentials.append(obp)

        if self_nb is not None and not merged_nonbonded:
            raise ValueError("self has a Nonbonded potential but other does not")

        return Recipe(combined_masses, combined_bound_potentials)
Пример #7
0
def _deserialize_harmonic_bonds(force):
    """Return a bound potentials.HarmonicBond for an OpenMM HarmonicBondForce."""
    bond_idxs = []
    bond_params = []
    for b_idx in range(force.getNumBonds()):
        src_idx, dst_idx, length, k = force.getBondParameters(b_idx)
        length = value(length)
        k = value(k)
        bond_idxs.append([src_idx, dst_idx])
        bond_params.append((k, length))

    # reshape keeps a 2D (0, 2) shape when the force has no bonds
    bond_idxs = np.array(bond_idxs, dtype=np.int32).reshape(-1, 2)
    bond_params = np.array(bond_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicBond(bond_idxs).bind(bond_params)


def _deserialize_harmonic_angles(force):
    """Return a bound potentials.HarmonicAngle for an OpenMM HarmonicAngleForce."""
    angle_idxs = []
    angle_params = []
    for a_idx in range(force.getNumAngles()):
        src_idx, mid_idx, dst_idx, angle, k = force.getAngleParameters(a_idx)
        angle = value(angle)
        k = value(k)
        angle_idxs.append([src_idx, mid_idx, dst_idx])
        angle_params.append((k, angle))

    # reshape keeps a 2D shape when the force has no angles
    angle_idxs = np.array(angle_idxs, dtype=np.int32).reshape(-1, 3)
    angle_params = np.array(angle_params, dtype=np.float64).reshape(-1, 2)
    return potentials.HarmonicAngle(angle_idxs).bind(angle_params)


def _deserialize_periodic_torsions(force):
    """Return a bound potentials.PeriodicTorsion for an OpenMM PeriodicTorsionForce."""
    torsion_idxs = []
    torsion_params = []
    for t_idx in range(force.getNumTorsions()):
        a_idx, b_idx, c_idx, d_idx, period, phase, k = force.getTorsionParameters(
            t_idx)
        phase = value(phase)
        k = value(k)
        torsion_params.append((k, phase, period))
        torsion_idxs.append([a_idx, b_idx, c_idx, d_idx])

    # reshape keeps a 2D shape when the force has no torsions
    torsion_idxs = np.array(torsion_idxs, dtype=np.int32).reshape(-1, 4)
    torsion_params = np.array(torsion_params, dtype=np.float64).reshape(-1, 3)
    return potentials.PeriodicTorsion(torsion_idxs).bind(torsion_params)


def _lj_scale_factor(new_eps, expected_eps):
    """Return the fraction of the combined LJ epsilon that an exception removes.

    1.0 means the interaction is fully removed. Raises RuntimeError if the
    combined epsilon is zero but the exception's epsilon is not.
    """
    if expected_eps == 0:
        if new_eps == 0:
            return 1.0
        raise RuntimeError("Divide by zero in epsilon calculation")
    return 1 - new_eps / expected_eps


def _deserialize_nonbonded(force, num_system_atoms, cutoff):
    """Return a bound potentials.Nonbonded for an OpenMM NonbondedForce.

    Per-particle params are [charge * sqrt(ONE_4PI_EPS0), sigma / 2, sqrt(eps)].
    Every exception becomes an exclusion whose scale factor is the fraction of
    the combined LJ epsilon it removes. The same factor is used for the charge
    column, because the exception charge product is not read yet. Lambda
    plane/offset indices are all zero, sized ``num_system_atoms``.
    """
    charge_params = []
    lj_params = []
    for a_idx in range(force.getNumParticles()):
        charge, sig, eps = force.getParticleParameters(a_idx)
        charge = value(charge) * np.sqrt(constants.ONE_4PI_EPS0)
        # eps is used as-is: bumping eps == 0 by 1e-3 to avoid a singularity
        # in the LJ parameter derivatives is disabled ("doesn't work for water")
        charge_params.append(charge)
        lj_params.append((value(sig), value(eps)))

    charge_params = np.array(charge_params, dtype=np.float64)
    lj_params = np.array(lj_params, dtype=np.float64).reshape(-1, 2)

    all_sig = lj_params[:, 0]
    all_eps = lj_params[:, 1]

    exclusion_idxs = []
    scale_factors = []

    # validate exclusions/exceptions to make sure they make sense
    for e_idx in range(force.getNumExceptions()):
        # TODO: charge scale factors should be derived from new_cp
        src, dst, new_cp, new_sig, new_eps = force.getExceptionParameters(e_idx)
        new_sig = value(new_sig)
        new_eps = value(new_eps)

        # Lorentz-Berthelot combining rule
        expected_sig = (all_sig[src] + all_sig[dst]) / 2
        expected_eps = np.sqrt(all_eps[src] * all_eps[dst])

        exclusion_idxs.append([src, dst])
        scale_factors.append(_lj_scale_factor(new_eps, expected_eps))

        if new_eps != 0:
            np.testing.assert_almost_equal(expected_sig, new_sig)

    exclusion_idxs = np.array(exclusion_idxs, dtype=np.int32).reshape(-1, 2)

    # use the same scale factors for electrostatics and lj
    scale_factors = np.array(scale_factors, dtype=np.float64)
    scale_factors = np.stack([scale_factors, scale_factors], axis=1)

    lambda_plane_idxs = np.zeros(num_system_atoms, dtype=np.int32)
    lambda_offset_idxs = np.zeros(num_system_atoms, dtype=np.int32)

    nb_params = np.concatenate(
        [np.expand_dims(charge_params, axis=1), lj_params], axis=1)

    # store sigma / 2 and sqrt(eps); presumably so the combining rule above
    # reduces to a sum and a product in the kernel -- TODO confirm
    nb_params[:, 1] = nb_params[:, 1] / 2
    nb_params[:, 2] = np.sqrt(nb_params[:, 2])

    beta = 2.0  # erfc correction

    return potentials.Nonbonded(exclusion_idxs, scale_factors,
                                lambda_plane_idxs, lambda_offset_idxs,
                                beta, cutoff).bind(nb_params)


def deserialize_system(system, cutoff):
    """
    Deserialize an OpenMM System into bound potentials and masses.

    Parameters
    ----------
    system: openmm.System
        A system object to be deserialized
    cutoff: float
        Nonbonded cutoff used for the ``potentials.Nonbonded`` term.

    Returns
    -------
    list of lib.Potential, masses
        One bound potential per supported force (HarmonicBondForce,
        HarmonicAngleForce, PeriodicTorsionForce, NonbondedForce), in the
        order of ``system.getForces()``. Other force types are ignored.
        ``masses`` is a list with one entry per particle.

    Note: zero LJ epsilon values are used as-is. Adding a small epsilon (1e-3)
    to avoid a singularity in the Lennard-Jones derivatives is disabled
    because it does not work for water.

    """
    masses = [value(system.getParticleMass(p))
              for p in range(system.getNumParticles())]

    # this should not be a dict since we may have more than one instance of a given
    # force.
    bps = []

    for force in system.getForces():
        if isinstance(force, mm.HarmonicBondForce):
            bps.append(_deserialize_harmonic_bonds(force))
        elif isinstance(force, mm.HarmonicAngleForce):
            bps.append(_deserialize_harmonic_angles(force))
        elif isinstance(force, mm.PeriodicTorsionForce):
            bps.append(_deserialize_periodic_torsions(force))
        elif isinstance(force, mm.NonbondedForce):
            bps.append(_deserialize_nonbonded(force, len(masses), cutoff))

    return bps, masses