def test_harmonic_bond_serialization():
    """Round-trip a HarmonicBondHandler through serialize/deserialize.

    Builds a handler from a fixed table of SMIRKS patterns and two per-pattern
    parameters, serializes it, deserializes the result, and checks that the
    single recovered handler has identical smirks, params and props.

    NOTE: this function was previously also named ``test_harmonic_bond``; a
    later definition with that same name rebound it, so pytest never
    collected it. It is renamed so that both tests actually run.
    """
    patterns = [
        ['[#6X4:1]-[#6X4:2]', 0.1, 0.2],
        ['[#6X4:1]-[#6X3:2]', 99., 99.],
        ['[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.],
        ['[#6X3:1]-[#6X3:2]', 99., 99.],
        ['[#6X3:1]:[#6X3:2]', 99., 99.],
        ['[#6X3:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]-[#7:2]', 0.1, 0.2],
        ['[#6X3:1]-[#7X3:2]', 99., 99.],
        ['[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.],
        ['[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.],
        ['[#6X3:1]-[#7X2:2]', 99., 99.],
        ['[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6:1]-[#8:2]', 99., 99.],
        ['[#6X3:1]-[#8X1-1:2]', 99., 99.],
        ['[#6X4:1]-[#8X2H0:2]', 0.3, 0.4],
        ['[#6X3:1]-[#8X2:2]', 99., 99.],
        ['[#6X3:1]-[#8X2H1:2]', 99., 99.],
        ['[#6X3a:1]-[#8X2H0:2]', 99., 99.],
        ['[#6X3:1](=[#8X1])-[#8X2H0:2]', 99., 99.],
        ['[#6:1]=[#8X1+0,#8X2+1:2]', 99., 99.],
        ['[#6X3:1](~[#8X1])~[#8X1:2]', 99., 99.],
        ['[#6X3:1]~[#8X2+1:2]~[#6X3]', 99., 99.],
        ['[#6X2:1]-[#6:2]', 99., 99.],
        ['[#6X2:1]-[#6X4:2]', 99., 99.],
        ['[#6X2:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]#[#7:2]', 99., 99.],
        ['[#6X2:1]#[#6X2:2]', 99., 99.],
        ['[#6X2:1]-[#8X2:2]', 99., 99.],
        ['[#6X2:1]-[#7:2]', 99., 99.],
        ['[#6X2:1]=[#7:2]', 99., 99.],
        ['[#16:1]=[#6:2]', 99., 99.],
        ['[#6X2:1]=[#16:2]', 99., 99.],
        ['[#7:1]-[#7:2]', 99., 99.],
        ['[#7X3:1]-[#7X2:2]', 99., 99.],
        ['[#7X2:1]-[#7X2:2]', 99., 99.],
        ['[#7:1]:[#7:2]', 99., 99.],
        ['[#7:1]=[#7:2]', 99., 99.],
        ['[#7+1:1]=[#7-1:2]', 99., 99.],
        ['[#7:1]#[#7:2]', 99., 99.],
        ['[#7:1]-[#8X2:2]', 99., 99.],
        ['[#7:1]~[#8X1:2]', 99., 99.],
        ['[#8X2:1]-[#8X2:2]', 99., 99.],
        ['[#16:1]-[#6:2]', 99., 99.],
        ['[#16:1]-[#1:2]', 99., 99.],
        ['[#16:1]-[#16:2]', 99., 99.],
        ['[#16:1]-[#9:2]', 99., 99.],
        ['[#16:1]-[#17:2]', 99., 99.],
        ['[#16:1]-[#35:2]', 99., 99.],
        ['[#16:1]-[#53:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X4:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X3:2]', 99., 99.],
        ['[#16X2:1]-[#7:2]', 99., 99.],
        ['[#16X2:1]-[#8X2:2]', 99., 99.],
        ['[#16X2:1]=[#8X1,#7X2:2]', 99., 99.],
        ['[#16X4,#16X3!+1:1]-[#6:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#7:2]', 99., 99.],
        ['[#16X4,#16X3:1]-[#8X2:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#8X1:2]', 99., 99.],
        ['[#15:1]-[#1:2]', 99., 99.],
        ['[#15:1]~[#6:2]', 99., 99.],
        ['[#15:1]-[#7:2]', 99., 99.],
        ['[#15:1]=[#7:2]', 99., 99.],
        ['[#15:1]~[#8X2:2]', 99., 99.],
        ['[#15:1]~[#8X1:2]', 99., 99.],
        ['[#16:1]-[#15:2]', 99., 99.],
        ['[#15:1]=[#16X1:2]', 99., 99.],
        ['[#6:1]-[#9:2]', 99., 99.],
        ['[#6X4:1]-[#9:2]', 0.6, 0.7],
        ['[#6:1]-[#17:2]', 99., 99.],
        ['[#6X4:1]-[#17:2]', 99., 99.],
        ['[#6:1]-[#35:2]', 99., 99.],
        ['[#6X4:1]-[#35:2]', 99., 99.],
        ['[#6:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#53:2]', 99., 99.],
        ['[#7:1]-[#9:2]', 99., 99.],
        ['[#7:1]-[#17:2]', 99., 99.],
        ['[#7:1]-[#35:2]', 99., 99.],
        ['[#7:1]-[#53:2]', 99., 99.],
        ['[#15:1]-[#9:2]', 99., 99.],
        ['[#15:1]-[#17:2]', 99., 99.],
        ['[#15:1]-[#35:2]', 99., 99.],
        ['[#15:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#1:2]', 99., 99.],
        ['[#6X3:1]-[#1:2]', 99., 99.],
        ['[#6X2:1]-[#1:2]', 99., 99.],
        ['[#7:1]-[#1:2]', 99., 99.],
        ['[#8:1]-[#1:2]', 99., 99.1],
    ]

    smirks = [x[0] for x in patterns]
    params = np.array([[x[1], x[2]] for x in patterns])
    props = None

    hbh = bonded.HarmonicBondHandler(smirks, params, props)
    obj = hbh.serialize()

    all_handlers = deserialize_handlers(bin_to_str(obj))
    assert len(all_handlers) == 1

    new_hbh = all_handlers[0]
    np.testing.assert_equal(new_hbh.smirks, hbh.smirks)
    np.testing.assert_equal(new_hbh.params, hbh.params)
    assert new_hbh.props == hbh.props
def test_harmonic_bond():
    """Parameterize a small molecule and check the parameter VJP.

    Builds a HarmonicBondHandler from a fixed SMIRKS table in which the entries
    with small values (0.1..0.7) are the ones intended to match the bonds of
    ``C1CNCOC1F``. Every other entry holds a 99.x sentinel. The test checks:

    * ``parameterize`` returns one (params, idxs) row per RDKit bond;
    * ``jax.vjp`` over ``partial_parameterize`` reproduces the same outputs;
    * the pulled-back adjoint is exactly zero for every sentinel (unused)
      force-field parameter.
    """
    patterns = [
        ['[#6X4:1]-[#6X4:2]', 0.1, 0.2],
        ['[#6X4:1]-[#6X3:2]', 99., 99.],
        ['[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.],
        ['[#6X3:1]-[#6X3:2]', 99., 99.],
        ['[#6X3:1]:[#6X3:2]', 99., 99.],
        ['[#6X3:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]-[#7:2]', 0.1, 0.2],
        ['[#6X3:1]-[#7X3:2]', 99., 99.],
        ['[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.],
        ['[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.],
        ['[#6X3:1]-[#7X2:2]', 99., 99.],
        ['[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6:1]-[#8:2]', 99., 99.],
        ['[#6X3:1]-[#8X1-1:2]', 99., 99.],
        ['[#6X4:1]-[#8X2H0:2]', 0.3, 0.4],
        ['[#6X3:1]-[#8X2:2]', 99., 99.],
        ['[#6X3:1]-[#8X2H1:2]', 99., 99.],
        ['[#6X3a:1]-[#8X2H0:2]', 99., 99.],
        ['[#6X3:1](=[#8X1])-[#8X2H0:2]', 99., 99.],
        ['[#6:1]=[#8X1+0,#8X2+1:2]', 99., 99.],
        ['[#6X3:1](~[#8X1])~[#8X1:2]', 99., 99.],
        ['[#6X3:1]~[#8X2+1:2]~[#6X3]', 99., 99.],
        ['[#6X2:1]-[#6:2]', 99., 99.],
        ['[#6X2:1]-[#6X4:2]', 99., 99.],
        ['[#6X2:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]#[#7:2]', 99., 99.],
        ['[#6X2:1]#[#6X2:2]', 99., 99.],
        ['[#6X2:1]-[#8X2:2]', 99., 99.],
        ['[#6X2:1]-[#7:2]', 99., 99.],
        ['[#6X2:1]=[#7:2]', 99., 99.],
        ['[#16:1]=[#6:2]', 99., 99.],
        ['[#6X2:1]=[#16:2]', 99., 99.],
        ['[#7:1]-[#7:2]', 99., 99.],
        ['[#7X3:1]-[#7X2:2]', 99., 99.],
        ['[#7X2:1]-[#7X2:2]', 99., 99.],
        ['[#7:1]:[#7:2]', 99., 99.],
        ['[#7:1]=[#7:2]', 99., 99.],
        ['[#7+1:1]=[#7-1:2]', 99., 99.],
        ['[#7:1]#[#7:2]', 99., 99.],
        ['[#7:1]-[#8X2:2]', 99., 99.],
        ['[#7:1]~[#8X1:2]', 99., 99.],
        ['[#8X2:1]-[#8X2:2]', 99., 99.],
        ['[#16:1]-[#6:2]', 99., 99.],
        ['[#16:1]-[#1:2]', 99., 99.],
        ['[#16:1]-[#16:2]', 99., 99.],
        ['[#16:1]-[#9:2]', 99., 99.],
        ['[#16:1]-[#17:2]', 99., 99.],
        ['[#16:1]-[#35:2]', 99., 99.],
        ['[#16:1]-[#53:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X4:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X3:2]', 99., 99.],
        ['[#16X2:1]-[#7:2]', 99., 99.],
        ['[#16X2:1]-[#8X2:2]', 99., 99.],
        ['[#16X2:1]=[#8X1,#7X2:2]', 99., 99.],
        ['[#16X4,#16X3!+1:1]-[#6:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#7:2]', 99., 99.],
        ['[#16X4,#16X3:1]-[#8X2:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#8X1:2]', 99., 99.],
        ['[#15:1]-[#1:2]', 99., 99.],
        ['[#15:1]~[#6:2]', 99., 99.],
        ['[#15:1]-[#7:2]', 99., 99.],
        ['[#15:1]=[#7:2]', 99., 99.],
        ['[#15:1]~[#8X2:2]', 99., 99.],
        ['[#15:1]~[#8X1:2]', 99., 99.],
        ['[#16:1]-[#15:2]', 99., 99.],
        ['[#15:1]=[#16X1:2]', 99., 99.],
        ['[#6:1]-[#9:2]', 99., 99.],
        ['[#6X4:1]-[#9:2]', 0.6, 0.7],
        ['[#6:1]-[#17:2]', 99., 99.],
        ['[#6X4:1]-[#17:2]', 99., 99.],
        ['[#6:1]-[#35:2]', 99., 99.],
        ['[#6X4:1]-[#35:2]', 99., 99.],
        ['[#6:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#53:2]', 99., 99.],
        ['[#7:1]-[#9:2]', 99., 99.],
        ['[#7:1]-[#17:2]', 99., 99.],
        ['[#7:1]-[#35:2]', 99., 99.],
        ['[#7:1]-[#53:2]', 99., 99.],
        ['[#15:1]-[#9:2]', 99., 99.],
        ['[#15:1]-[#17:2]', 99., 99.],
        ['[#15:1]-[#35:2]', 99., 99.],
        ['[#15:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#1:2]', 99., 99.],
        ['[#6X3:1]-[#1:2]', 99., 99.],
        ['[#6X2:1]-[#1:2]', 99., 99.],
        ['[#7:1]-[#1:2]', 99., 99.],
        ['[#8:1]-[#1:2]', 99., 99.1],
    ]

    smirks = [x[0] for x in patterns]
    params = np.array([[x[1], x[2]] for x in patterns])
    props = None

    hbh = bonded.HarmonicBondHandler(smirks, params, props)

    mol = Chem.MolFromSmiles("C1CNCOC1F")

    bond_params, bond_idxs = hbh.parameterize(mol)

    assert bond_idxs.shape == (mol.GetNumBonds(), 2)
    assert bond_params.shape == (mol.GetNumBonds(), 2)

    # Precondition for the zero-adjoint check below: no bond of this molecule
    # may be assigned a 99.x sentinel, i.e. only the small-valued patterns are
    # used. If this fails, the pattern table no longer means what the test
    # assumes.
    assert np.all(np.asarray(bond_params) < 90)

    # Seeded so that a failure is reproducible.
    rng = np.random.default_rng(0)
    bonded_param_adjoints = rng.standard_normal(bond_params.shape)

    bond_params_new, bond_vjp_fn, bond_idxs_new = jax.vjp(
        functools.partial(hbh.partial_parameterize, mol=mol),
        hbh.params,
        has_aux=True,
    )

    np.testing.assert_array_equal(bond_params_new, bond_params)
    np.testing.assert_array_equal(bond_idxs_new, bond_idxs)

    # Test that we can use the adjoints.
    ff_adjoints = np.asarray(bond_vjp_fn(bonded_param_adjoints)[0])

    # Any force-field parameter > 90 is a sentinel that no bond uses, so its
    # adjoint must be exactly zero. (The converse isn't necessarily true: a
    # used parameter's adjoint could be zero by chance or by cancellation.)
    # The mask must be built in force-field-parameter space (hbh.params), not
    # in per-bond space (bond_params); the two have different first axes.
    unused = np.asarray(hbh.params) > 90
    assert unused.any()
    np.testing.assert_array_equal(ff_adjoints[unused], 0.0)