def test_proper_torsion():
    """Round-trip a ProperTorsionHandler through serialize/deserialize_handlers.

    Builds a handler from SMIRKS patterns that carry different numbers of
    torsion terms. It serializes the handler, deserializes the result and
    checks that exactly one handler comes back with identical smirks, params
    and props.
    """
    # proper torsions have a variadic number of terms
    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        [
            '[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]',
            [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]],
        ],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]
    smirks = [x[0] for x in patterns]
    params = [x[1] for x in patterns]
    props = None

    # Pass the local `props` (previously a literal None was passed and this
    # variable went unused), so the round-trip check below compares against
    # the value actually supplied to the constructor.
    ph = bonded.ProperTorsionHandler(smirks, params, props)
    obj = ph.serialize()
    all_handlers = deserialize_handlers(bin_to_str(obj))

    assert len(all_handlers) == 1
    new_ph = all_handlers[0]
    np.testing.assert_equal(new_ph.smirks, ph.smirks)
    np.testing.assert_equal(new_ph.params, ph.params)
    assert new_ph.props == ph.props
def test_proper_torsion():
    """Check proper-torsion parameterization and that unused terms get no gradient.

    Parameterizes FC(Br)=C(Br)F and expects 8 torsion terms. It then pulls
    random adjoints back to the force-field parameters and asserts that the
    99-valued sentinel terms receive exactly zero gradient. For this molecule,
    more specific patterns are expected to take precedence over the generic
    sentinel patterns.
    """
    # proper torsions have a variadic number of terms
    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        [
            '[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]',
            [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]],
        ],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]
    smirks = [x[0] for x in patterns]
    params = [x[1] for x in patterns]
    props = None
    hbh = bonded.ProperTorsionHandler(smirks, params, props)

    mol = Chem.MolFromSmiles("FC(Br)=C(Br)F")
    torsion_idxs, (torsion_params, torsion_vjp_fn) = hbh.parameterize(mol)

    assert torsion_idxs.shape == (8, 4)
    assert torsion_params.shape == (8, 3)

    # Seeded so that any failure is reproducible.
    rng = np.random.default_rng(0)
    torsion_param_adjoints = rng.standard_normal(torsion_params.shape)
    ff_adjoints = torsion_vjp_fn(torsion_param_adjoints)[0]

    # The 99-valued sentinels live in the force-field parameters, not in the
    # per-molecule torsion_params. Masking torsion_params (as before) selected
    # nothing, so the check always passed. A boolean mask also selects
    # elements; np.argwhere output used as a fancy index selected whole rows.
    # NOTE(review): assumes the VJP is taken w.r.t. hbh.params, so that
    # ff_adjoints has the same shape as hbh.params — confirm.
    unused = np.asarray(hbh.params) > 90
    assert np.any(unused), "expected sentinel (unused) parameters in the handler"
    np.testing.assert_array_equal(np.asarray(ff_adjoints)[unused], 0.0)
def test_proper_torsion():
    """Check proper-torsion parameterization, its JAX VJP, and zero gradients for unused terms.

    Parameterizes FC(Br)=C(Br)F (8 torsion terms expected) and verifies that
    differentiating `partial_parameterize` with `jax.vjp` reproduces the same
    params and indices. It then checks that the 99-valued sentinel
    force-field terms receive exactly zero gradient.
    """
    # proper torsions have a variadic number of terms
    patterns = [
        ['[*:1]-[#6X3:2]=[#6X3:3]-[*:4]', [[99., 99., 99.]]],
        ['[*:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[99., 99., 99.]]],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#35:4]', [[1., 2., 3.], [4., 5., 6.]]],
        [
            '[#35:1]-[#6X3:2]=[#6X3:3]-[#35:4]',
            [[7., 8., 9.], [1., 3., 5.], [4., 4., 4.]],
        ],
        ['[#9:1]-[#6X3:2]=[#6X3:3]-[#9:4]', [[7., 8., 9.]]],
    ]
    smirks = [x[0] for x in patterns]
    params = [x[1] for x in patterns]
    props = None
    pth = bonded.ProperTorsionHandler(smirks, params, props)

    mol = Chem.MolFromSmiles("FC(Br)=C(Br)F")
    torsion_params, torsion_idxs = pth.parameterize(mol)

    assert torsion_idxs.shape == (8, 4)
    assert torsion_params.shape == (8, 3)

    # With has_aux=True, jax.vjp returns (primals_out, vjp_fn, aux). The
    # indices ride along as the auxiliary (non-differentiated) output.
    torsion_params_new, torsion_vjp_fn, torsion_idxs_new = jax.vjp(
        functools.partial(pth.partial_parameterize, mol=mol),
        pth.params,
        has_aux=True,
    )

    np.testing.assert_array_equal(torsion_params_new, torsion_params)
    np.testing.assert_array_equal(torsion_idxs_new, torsion_idxs)

    # Seeded so that any failure is reproducible.
    rng = np.random.default_rng(0)
    torsion_param_adjoints = rng.standard_normal(torsion_params.shape)
    # Cotangent w.r.t. the single primal pth.params, so it has its shape.
    ff_adjoints = torsion_vjp_fn(torsion_param_adjoints)[0]

    # The 99-valued sentinels live in pth.params, not in the per-molecule
    # torsion_params. Masking torsion_params (as before) selected nothing, so
    # the check always passed. A boolean mask also selects elements;
    # np.argwhere output used as a fancy index selected whole rows.
    unused = np.asarray(pth.params) > 90
    assert np.any(unused), "expected sentinel (unused) parameters in the handler"
    np.testing.assert_array_equal(np.asarray(ff_adjoints)[unused], 0.0)