예제 #1
0
def test_harmonic_bond():
    """Round-trip a HarmonicBondHandler through ``serialize``/``deserialize_handlers``.

    Builds a handler from a table of (SMIRKS, param_0, param_1) rows, serializes it,
    deserializes the result and checks that exactly one handler of the same type
    comes back with identical smirks, params and props.
    """
    patterns = [
        ['[#6X4:1]-[#6X4:2]', 0.1, 0.2],
        ['[#6X4:1]-[#6X3:2]', 99., 99.],
        ['[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.],
        ['[#6X3:1]-[#6X3:2]', 99., 99.],
        ['[#6X3:1]:[#6X3:2]', 99., 99.],
        ['[#6X3:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]-[#7:2]', 0.1, 0.2],
        ['[#6X3:1]-[#7X3:2]', 99., 99.],
        ['[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.],
        ['[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.],
        ['[#6X3:1]-[#7X2:2]', 99., 99.],
        ['[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6:1]-[#8:2]', 99., 99.],
        ['[#6X3:1]-[#8X1-1:2]', 99., 99.],
        ['[#6X4:1]-[#8X2H0:2]', 0.3, 0.4],
        ['[#6X3:1]-[#8X2:2]', 99., 99.],
        ['[#6X3:1]-[#8X2H1:2]', 99., 99.],
        ['[#6X3a:1]-[#8X2H0:2]', 99., 99.],
        ['[#6X3:1](=[#8X1])-[#8X2H0:2]', 99., 99.],
        ['[#6:1]=[#8X1+0,#8X2+1:2]', 99., 99.],
        ['[#6X3:1](~[#8X1])~[#8X1:2]', 99., 99.],
        ['[#6X3:1]~[#8X2+1:2]~[#6X3]', 99., 99.],
        ['[#6X2:1]-[#6:2]', 99., 99.],
        ['[#6X2:1]-[#6X4:2]', 99., 99.],
        ['[#6X2:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]#[#7:2]', 99., 99.],
        ['[#6X2:1]#[#6X2:2]', 99., 99.],
        ['[#6X2:1]-[#8X2:2]', 99., 99.],
        ['[#6X2:1]-[#7:2]', 99., 99.],
        ['[#6X2:1]=[#7:2]', 99., 99.],
        ['[#16:1]=[#6:2]', 99., 99.],
        ['[#6X2:1]=[#16:2]', 99., 99.],
        ['[#7:1]-[#7:2]', 99., 99.],
        ['[#7X3:1]-[#7X2:2]', 99., 99.],
        ['[#7X2:1]-[#7X2:2]', 99., 99.],
        ['[#7:1]:[#7:2]', 99., 99.],
        ['[#7:1]=[#7:2]', 99., 99.],
        ['[#7+1:1]=[#7-1:2]', 99., 99.],
        ['[#7:1]#[#7:2]', 99., 99.],
        ['[#7:1]-[#8X2:2]', 99., 99.],
        ['[#7:1]~[#8X1:2]', 99., 99.],
        ['[#8X2:1]-[#8X2:2]', 99., 99.],
        ['[#16:1]-[#6:2]', 99., 99.],
        ['[#16:1]-[#1:2]', 99., 99.],
        ['[#16:1]-[#16:2]', 99., 99.],
        ['[#16:1]-[#9:2]', 99., 99.],
        ['[#16:1]-[#17:2]', 99., 99.],
        ['[#16:1]-[#35:2]', 99., 99.],
        ['[#16:1]-[#53:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X4:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X3:2]', 99., 99.],
        ['[#16X2:1]-[#7:2]', 99., 99.],
        ['[#16X2:1]-[#8X2:2]', 99., 99.],
        ['[#16X2:1]=[#8X1,#7X2:2]', 99., 99.],
        ['[#16X4,#16X3!+1:1]-[#6:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#7:2]', 99., 99.],
        ['[#16X4,#16X3:1]-[#8X2:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#8X1:2]', 99., 99.],
        ['[#15:1]-[#1:2]', 99., 99.],
        ['[#15:1]~[#6:2]', 99., 99.],
        ['[#15:1]-[#7:2]', 99., 99.],
        ['[#15:1]=[#7:2]', 99., 99.],
        ['[#15:1]~[#8X2:2]', 99., 99.],
        ['[#15:1]~[#8X1:2]', 99., 99.],
        ['[#16:1]-[#15:2]', 99., 99.],
        ['[#15:1]=[#16X1:2]', 99., 99.],
        ['[#6:1]-[#9:2]', 99., 99.],
        ['[#6X4:1]-[#9:2]', 0.6, 0.7],
        ['[#6:1]-[#17:2]', 99., 99.],
        ['[#6X4:1]-[#17:2]', 99., 99.],
        ['[#6:1]-[#35:2]', 99., 99.],
        ['[#6X4:1]-[#35:2]', 99., 99.],
        ['[#6:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#53:2]', 99., 99.],
        ['[#7:1]-[#9:2]', 99., 99.],
        ['[#7:1]-[#17:2]', 99., 99.],
        ['[#7:1]-[#35:2]', 99., 99.],
        ['[#7:1]-[#53:2]', 99., 99.],
        ['[#15:1]-[#9:2]', 99., 99.],
        ['[#15:1]-[#17:2]', 99., 99.],
        ['[#15:1]-[#35:2]', 99., 99.],
        ['[#15:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#1:2]', 99., 99.],
        ['[#6X3:1]-[#1:2]', 99., 99.],
        ['[#6X2:1]-[#1:2]', 99., 99.],
        ['[#7:1]-[#1:2]', 99., 99.],
        ['[#8:1]-[#1:2]', 99., 99.1],
    ]

    smirks = [x[0] for x in patterns]
    params = np.array([[x[1], x[2]] for x in patterns])
    props = None
    # pass the named local rather than a literal so the props being round-tripped
    # are the ones declared above
    hbh = bonded.HarmonicBondHandler(smirks, params, props)

    obj = hbh.serialize()
    # serialize() output goes through bin_to_str before deserialize_handlers
    all_handlers = deserialize_handlers(bin_to_str(obj))

    assert len(all_handlers) == 1

    new_hbh = all_handlers[0]
    # the round trip should rebuild the same kind of handler, not just equal data
    assert type(new_hbh) is type(hbh)
    np.testing.assert_equal(new_hbh.smirks, hbh.smirks)
    np.testing.assert_equal(new_hbh.params, hbh.params)

    assert new_hbh.props == hbh.props
예제 #2
0
def test_harmonic_bond():
    """Check HarmonicBondHandler parameterization and its VJP on ``C1CNCOC1F``.

    Most patterns carry the placeholder value 99 (or 99.1). These are expected not
    to be selected for this molecule; a few patterns carry small distinct values.
    The test verifies three things:
      * ``parameterize`` returns one row per bond;
      * ``partial_parameterize`` under ``jax.vjp`` reproduces the same outputs;
      * placeholder (unused) forcefield parameters receive zero adjoint.
    """
    patterns = [
        ['[#6X4:1]-[#6X4:2]', 0.1, 0.2],
        ['[#6X4:1]-[#6X3:2]', 99., 99.],
        ['[#6X4:1]-[#6X3:2]=[#8X1+0]', 99., 99.],
        ['[#6X3:1]-[#6X3:2]', 99., 99.],
        ['[#6X3:1]:[#6X3:2]', 99., 99.],
        ['[#6X3:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]-[#7:2]', 0.1, 0.2],
        ['[#6X3:1]-[#7X3:2]', 99., 99.],
        ['[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0]', 99., 99.],
        ['[#6X3:1](=[#8X1+0])-[#7X3:2]', 99., 99.],
        ['[#6X3:1]-[#7X2:2]', 99., 99.],
        ['[#6X3:1]:[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6X3:1]=[#7X2,#7X3+1:2]', 99., 99.],
        ['[#6:1]-[#8:2]', 99., 99.],
        ['[#6X3:1]-[#8X1-1:2]', 99., 99.],
        ['[#6X4:1]-[#8X2H0:2]', 0.3, 0.4],
        ['[#6X3:1]-[#8X2:2]', 99., 99.],
        ['[#6X3:1]-[#8X2H1:2]', 99., 99.],
        ['[#6X3a:1]-[#8X2H0:2]', 99., 99.],
        ['[#6X3:1](=[#8X1])-[#8X2H0:2]', 99., 99.],
        ['[#6:1]=[#8X1+0,#8X2+1:2]', 99., 99.],
        ['[#6X3:1](~[#8X1])~[#8X1:2]', 99., 99.],
        ['[#6X3:1]~[#8X2+1:2]~[#6X3]', 99., 99.],
        ['[#6X2:1]-[#6:2]', 99., 99.],
        ['[#6X2:1]-[#6X4:2]', 99., 99.],
        ['[#6X2:1]=[#6X3:2]', 99., 99.],
        ['[#6:1]#[#7:2]', 99., 99.],
        ['[#6X2:1]#[#6X2:2]', 99., 99.],
        ['[#6X2:1]-[#8X2:2]', 99., 99.],
        ['[#6X2:1]-[#7:2]', 99., 99.],
        ['[#6X2:1]=[#7:2]', 99., 99.],
        ['[#16:1]=[#6:2]', 99., 99.],
        ['[#6X2:1]=[#16:2]', 99., 99.],
        ['[#7:1]-[#7:2]', 99., 99.],
        ['[#7X3:1]-[#7X2:2]', 99., 99.],
        ['[#7X2:1]-[#7X2:2]', 99., 99.],
        ['[#7:1]:[#7:2]', 99., 99.],
        ['[#7:1]=[#7:2]', 99., 99.],
        ['[#7+1:1]=[#7-1:2]', 99., 99.],
        ['[#7:1]#[#7:2]', 99., 99.],
        ['[#7:1]-[#8X2:2]', 99., 99.],
        ['[#7:1]~[#8X1:2]', 99., 99.],
        ['[#8X2:1]-[#8X2:2]', 99., 99.],
        ['[#16:1]-[#6:2]', 99., 99.],
        ['[#16:1]-[#1:2]', 99., 99.],
        ['[#16:1]-[#16:2]', 99., 99.],
        ['[#16:1]-[#9:2]', 99., 99.],
        ['[#16:1]-[#17:2]', 99., 99.],
        ['[#16:1]-[#35:2]', 99., 99.],
        ['[#16:1]-[#53:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X4:2]', 99., 99.],
        ['[#16X2,#16X1-1,#16X3+1:1]-[#6X3:2]', 99., 99.],
        ['[#16X2:1]-[#7:2]', 99., 99.],
        ['[#16X2:1]-[#8X2:2]', 99., 99.],
        ['[#16X2:1]=[#8X1,#7X2:2]', 99., 99.],
        ['[#16X4,#16X3!+1:1]-[#6:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#7:2]', 99., 99.],
        ['[#16X4,#16X3:1]-[#8X2:2]', 99., 99.],
        ['[#16X4,#16X3:1]~[#8X1:2]', 99., 99.],
        ['[#15:1]-[#1:2]', 99., 99.],
        ['[#15:1]~[#6:2]', 99., 99.],
        ['[#15:1]-[#7:2]', 99., 99.],
        ['[#15:1]=[#7:2]', 99., 99.],
        ['[#15:1]~[#8X2:2]', 99., 99.],
        ['[#15:1]~[#8X1:2]', 99., 99.],
        ['[#16:1]-[#15:2]', 99., 99.],
        ['[#15:1]=[#16X1:2]', 99., 99.],
        ['[#6:1]-[#9:2]', 99., 99.],
        ['[#6X4:1]-[#9:2]', 0.6, 0.7],
        ['[#6:1]-[#17:2]', 99., 99.],
        ['[#6X4:1]-[#17:2]', 99., 99.],
        ['[#6:1]-[#35:2]', 99., 99.],
        ['[#6X4:1]-[#35:2]', 99., 99.],
        ['[#6:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#53:2]', 99., 99.],
        ['[#7:1]-[#9:2]', 99., 99.],
        ['[#7:1]-[#17:2]', 99., 99.],
        ['[#7:1]-[#35:2]', 99., 99.],
        ['[#7:1]-[#53:2]', 99., 99.],
        ['[#15:1]-[#9:2]', 99., 99.],
        ['[#15:1]-[#17:2]', 99., 99.],
        ['[#15:1]-[#35:2]', 99., 99.],
        ['[#15:1]-[#53:2]', 99., 99.],
        ['[#6X4:1]-[#1:2]', 99., 99.],
        ['[#6X3:1]-[#1:2]', 99., 99.],
        ['[#6X2:1]-[#1:2]', 99., 99.],
        ['[#7:1]-[#1:2]', 99., 99.],
        ['[#8:1]-[#1:2]', 99., 99.1],
    ]

    smirks = [x[0] for x in patterns]
    params = np.array([[x[1], x[2]] for x in patterns])
    props = None
    hbh = bonded.HarmonicBondHandler(smirks, params, props)

    mol = Chem.MolFromSmiles("C1CNCOC1F")

    bond_params, bond_idxs = hbh.parameterize(mol)

    # one row per bond in the molecule for both outputs
    assert bond_idxs.shape == (mol.GetNumBonds(), 2)
    assert bond_params.shape == (mol.GetNumBonds(), 2)

    # seeded so any failure is reproducible
    rng = np.random.default_rng(2020)
    bonded_param_adjoints = rng.standard_normal(bond_params.shape)

    # has_aux=True: the partial returns (bond_params, bond_idxs); only the first
    # output is differentiated and the indices come back as auxiliary data
    bond_params_new, bond_vjp_fn, bond_idxs_new = jax.vjp(
        functools.partial(hbh.partial_parameterize, mol=mol),
        hbh.params,
        has_aux=True,
    )

    np.testing.assert_array_equal(bond_params_new, bond_params)
    np.testing.assert_array_equal(bond_idxs_new, bond_idxs)

    # test that we can use the adjoints
    ff_adjoints = np.asarray(bond_vjp_fn(bonded_param_adjoints)[0])
    # the cotangent w.r.t. the forcefield params has the params' shape
    assert ff_adjoints.shape == np.shape(hbh.params)

    # Any forcefield parameter set to the placeholder value (> 90) must receive a
    # zero adjoint. The mask is built from the forcefield params, not from
    # ``bond_params``: ``bond_params`` has one row per bond, while ``ff_adjoints``
    # has one row per SMIRKS pattern, so indices derived from the former do not
    # address the latter. The converse isn't necessarily true, since a used
    # parameter's per-bond contributions could happen to cancel to zero.
    unused = np.asarray(hbh.params) > 90
    # guard against the check silently selecting nothing
    assert np.any(unused)
    assert np.all(ff_adjoints[unused] == 0.0)