Esempio n. 1
0
class TestDeepDipolePBC(unittest.TestCase):
    """Inference checks for ``DeepDipole`` when an explicit box is supplied.

    The text graph ``infer/deepdipole.pbtxt`` is converted to a binary model
    and loaded once for the whole class.  Every test evaluates the same
    six-atom frame (types ``[0, 1, 1, 0, 1, 1]``) and compares the result
    with stored reference values.
    """

    # Temporary binary model written to the current working directory.
    # NOTE(review): concurrent runs in the same directory would collide on
    # this file name.
    MODEL_FILE = "deepdipole.pb"
    # Second dimension of the array returned by ``eval``.
    # NOTE(review): presumably the two type-0 atoms, since get_sel_type()
    # is asserted to be [0] -- confirm.
    NSEL = 2

    @classmethod
    def setUpClass(cls):
        """Convert and load the model once; it is not modified by the tests."""
        convert_pbtxt_to_pb(
            str(tests_path / os.path.join("infer", "deepdipole.pbtxt")),
            cls.MODEL_FILE)
        cls.dp = DeepDipole(cls.MODEL_FILE)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary binary model created in ``setUpClass``."""
        os.remove(cls.MODEL_FILE)

    def setUp(self):
        """Build fresh input arrays and reference values for each test."""
        # Flattened coordinates: 6 atoms x 3 components.
        self.coords = np.array([
            12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68, 3.36,
            3.00, 1.81, 3.51, 2.51, 2.60, 4.27, 3.22, 1.56
        ])
        self.atype = [0, 1, 1, 0, 1, 1]
        # Nine-component box with 13 on the diagonal.
        self.box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.])
        # Flattened reference output for one frame (NSEL x 3 values).
        self.expected_d = np.array([
            -9.274180565967479195e-01, 2.698028341272042496e+00,
            2.521268387140979117e-01, 2.927260638453461628e+00,
            -8.571926301526779923e-01, 1.667785136187720063e+00
        ])

    def _assert_dipole(self, dd, nframes):
        """Check shape and values of ``dd`` for ``nframes`` identical frames."""
        self.assertEqual(dd.shape, (nframes, self.NSEL, 3))
        # Every frame is a copy of the same configuration, so the reference
        # is simply repeated once per frame.
        np.testing.assert_almost_equal(dd.ravel(),
                                       np.tile(self.expected_d, nframes),
                                       default_places)

    def test_attrs(self):
        """The loaded model reports the expected metadata."""
        self.assertEqual(self.dp.get_ntypes(), 2)
        self.assertAlmostEqual(self.dp.get_rcut(), 4.0, places=default_places)
        self.assertEqual(self.dp.get_type_map(), ['O', 'H'])
        self.assertEqual(self.dp.get_sel_type(), [0])

    def test_1frame_atm(self):
        """Evaluating one frame reproduces the reference values."""
        dd = self.dp.eval(self.coords, self.box, self.atype)
        self._assert_dipole(dd, nframes=1)

    def test_2frame_atm(self):
        """Evaluating two stacked copies of the frame gives the reference twice."""
        coords2 = np.concatenate((self.coords, self.coords))
        box2 = np.concatenate((self.box, self.box))
        dd = self.dp.eval(coords2, box2, self.atype)
        self._assert_dipole(dd, nframes=2)
Esempio n. 2
0
class TestDeepDipoleNoPBC(unittest.TestCase):
    """Inference checks for ``DeepDipole`` without a box and with a large box.

    The text graph ``infer/deepdipole.pbtxt`` is converted to a binary model
    and loaded once for the whole class.  The same six-atom frame is
    evaluated once with ``None`` as the box and once with a box of side 20;
    both results are compared with the same reference values.
    """

    # Temporary binary model written to the current working directory.
    # NOTE(review): concurrent runs in the same directory would collide on
    # this file name.
    MODEL_FILE = "deepdipole.pb"
    # Second dimension of the array returned by ``eval``.
    # NOTE(review): presumably the number of type-0 atoms -- confirm.
    NSEL = 2

    @classmethod
    def setUpClass(cls):
        """Convert and load the model once; it is not modified by the tests."""
        convert_pbtxt_to_pb(
            str(tests_path / os.path.join("infer", "deepdipole.pbtxt")),
            cls.MODEL_FILE)
        cls.dp = DeepDipole(cls.MODEL_FILE)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary binary model created in ``setUpClass``."""
        os.remove(cls.MODEL_FILE)

    def setUp(self):
        """Build fresh input arrays and reference values for each test."""
        # Flattened coordinates: 6 atoms x 3 components.
        self.coords = np.array([
            12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68, 3.36,
            3.00, 1.81, 3.51, 2.51, 2.60, 4.27, 3.22, 1.56
        ])
        self.atype = [0, 1, 1, 0, 1, 1]
        # Nine-component box with 20 on the diagonal, used only by the
        # large-box test.
        self.box = np.array([20., 0., 0., 0., 20., 0., 0., 0., 20.])
        # Flattened reference output for one frame (NSEL x 3 values).
        self.expected_d = np.array([
            -1.982092647058316e+00, 8.303361089028074e-01,
            1.499962003179265e+00, 2.927112547154802e+00,
            -8.572096473802318e-01, 1.667798310054391e+00
        ])

    def _assert_single_frame(self, dd):
        """Check shape and values of a one-frame result ``dd``."""
        self.assertEqual(dd.shape, (1, self.NSEL, 3))
        np.testing.assert_almost_equal(dd.ravel(), self.expected_d,
                                       default_places)

    def test_1frame_atm(self):
        """Passing ``None`` as the box reproduces the reference values."""
        dd = self.dp.eval(self.coords, None, self.atype)
        self._assert_single_frame(dd)

    def test_1frame_atm_large_box(self):
        """A box of side 20 gives the same reference values as no box."""
        dd = self.dp.eval(self.coords, self.box, self.atype)
        self._assert_single_frame(dd)
Esempio n. 3
0
class TestDeepDipoleNewPBC(unittest.TestCase):
    """Inference checks for ``DeepDipole`` loaded from ``deepdipole_new.pbtxt``.

    The text graph is converted to a binary model and loaded once for the
    whole class.  The tests cover ``eval`` (global and per-atom output) and
    ``eval_full`` (output plus the quantities the tests call force and
    virial), for one frame, two stacked frames and a shuffled atom order,
    and compare everything with stored reference values.
    """

    # Temporary binary model written to the current working directory.
    # NOTE(review): concurrent runs in the same directory would collide on
    # this file name.
    MODEL_FILE = "deepdipole_new.pb"
    # Second dimension of the per-atom output array.
    # NOTE(review): presumably the two type-0 atoms, since get_sel_type()
    # is asserted to be [0] -- confirm.
    NSEL = 2

    @classmethod
    def setUpClass(cls):
        """Convert and load the model once; it is not modified by the tests."""
        convert_pbtxt_to_pb(
            str(tests_path / os.path.join("infer", "deepdipole_new.pbtxt")),
            cls.MODEL_FILE)
        cls.dp = DeepDipole(cls.MODEL_FILE)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary binary model created in ``setUpClass``."""
        os.remove(cls.MODEL_FILE)

    def setUp(self):
        """Build fresh input arrays and reference values for each test."""
        # Flattened coordinates: 6 atoms x 3 components.
        self.coords = np.array([
            12.83, 2.56, 2.18, 12.09, 2.87, 2.74, 00.25, 3.32, 1.68, 3.36,
            3.00, 1.81, 3.51, 2.51, 2.60, 4.27, 3.22, 1.56
        ])
        # Number of output components per frame / per selected atom.
        self.nout = 3
        # An array (not a list) so the shuffle test can fancy-index it.
        self.atype = np.array([0, 1, 1, 0, 1, 1])
        # Nine-component box with 13 on the diagonal.
        self.box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.])
        # Per-atom reference output, flattened from (NSEL, nout).
        self.expected_t = np.array([
            -1.128427726201255282e-01, 2.654103846999197880e-01,
            2.625816377288122533e-02, 3.027556488877700680e-01,
            -7.475444785689989990e-02, 1.526291164572509684e-01
        ])
        # Reference for the second ``eval_full`` output, flattened from
        # (nout, natoms, 3).
        self.expected_f = np.array([
            8.424897862241968738e-02, -3.823566783202275721e-02,
            3.570797165027734810e-01, 6.102563129736437997e-02,
            -1.351209759852018133e-01, -2.438224487466488510e-01,
            -1.403204771681088869e-01, 1.719596545791735875e-01,
            -1.136584427103610045e-01, 2.761686212947551955e-02,
            -7.247860200915196005e-02, 6.208831127377397591e-02,
            -2.605870723577520809e-01, -4.504074577536486268e-02,
            7.340240097998475266e-02, 2.280160774766013809e-01,
            1.189163370225677641e-01, -1.350895372995223886e-01,
            -4.294311497114180337e-02, 1.524802094783661577e-01,
            1.070451777645946290e-01, -1.259336332521076574e-01,
            -2.087610788959351760e-01, 9.447141346538817652e-02,
            1.668125597515543457e-01, 5.487037060760904805e-02,
            -2.014994036104674757e-01, -7.411985441205551361e-02,
            3.614456658821710300e-01, 2.901174891391154476e-01,
            -4.871926969937838414e-02, -1.252747945819455699e-01,
            -2.555459318266457558e-01, 1.249033125831290059e-01,
            -2.347603724902655176e-01, -3.458874493198500766e-02,
            3.563990394229877290e-01, 1.052342031228763047e-01,
            1.907268232932498031e-01, -2.432737821373903708e-01,
            1.016781829972335099e-01, -7.707616437996064884e-02,
            -1.139199805053340564e-01, -2.068592154909300040e-01,
            -1.156337826476897951e-01, 6.583817133933017596e-02,
            2.902207490750204344e-01, 9.945482314729316153e-02,
            7.986986504051810098e-02, -2.549975565538568079e-01,
            1.275343199697696051e-01, -1.449133131601115787e-01,
            -3.527636315034351350e-02, -2.250060193826620980e-01
        ])
        # Reference for the per-atom fifth ``eval_full`` output, flattened
        # from (nout, natoms, 9).
        self.expected_v = np.array([
            3.479789535931299138e-02, 4.337414719007849292e-03,
            -3.647371468256610082e-03, 8.053492919528318708e-03,
            1.003834811499279773e-03, -8.441338187607602033e-04,
            -6.695998268698949256e-03, -8.346286793845711892e-04,
            7.018468440279366279e-04, -4.515896716004976635e-02,
            1.891794570218296306e-02, 3.417435352652402336e-02,
            9.998952222904963771e-02, -4.188750255541257711e-02,
            -7.566774655171297492e-02, 1.804286120725206444e-01,
            -7.558495911146115298e-02, -1.365405712981232755e-01,
            -1.002593446510361419e-01, -1.117945222697993429e-01,
            7.449172735713084637e-02, 7.770237313970995707e-02,
            1.313723119887387492e-01, -8.655414676270002661e-02,
            -4.973937467461287537e-02, -8.663006083493235421e-02,
            5.703914957966123994e-02, -3.382231967662072125e-02,
            -4.215813217482468345e-03, 3.545115660155720612e-03,
            -8.247565860499378454e-03, -1.028025206407854253e-03,
            8.644757417520612143e-04, 6.761330949063471332e-03,
            8.427721296283078580e-04, -7.086947453692606178e-04,
            -1.622698090933780493e-02, 1.305372051650728060e-01,
            -2.082599910094798112e-01, -7.109985131471197733e-03,
            2.202585658101286273e-02, -3.554509763049529952e-02,
            1.436400379134906459e-02, -3.554915857551419617e-02,
            5.763638171798115412e-02, 2.074946305037073946e-01,
            5.016353704485233822e-02, -5.700401936915034523e-02,
            1.082138666905367308e-01, 2.616159414496492877e-02,
            -2.972908425564194101e-02, -1.229314789425654392e-01,
            -2.971969820589494271e-02, 3.377238432488059716e-02,
            7.622024445219390681e-03, 9.500540384976005961e-04,
            -7.989090778275298932e-04, -2.952148931042387209e-02,
            -3.679732378636401541e-03, 3.094320409307891630e-03,
            -9.534268115386618486e-04, -1.188407357158671420e-04,
            9.993425503379762414e-05, 9.319088860655992679e-02,
            -3.903942630815338682e-02, -7.052283462118023871e-02,
            1.544831983829924038e-01, -6.471593445773991815e-02,
            -1.169062041817236081e-01, -6.990884596438741438e-02,
            2.928613817427033750e-02, 5.290399154061733306e-02,
            7.491400658274136037e-02, 1.273824184577304897e-01,
            -8.391492311946648075e-02, 3.543872837542783732e-02,
            4.324623973455964804e-02, -2.873418641045778418e-02,
            -8.444981234074398768e-02, -1.531171183141288306e-01,
            1.007308415346981068e-01, -6.396885751015785743e-03,
            -7.973455327045167592e-04, 6.704951070469818575e-04,
            2.915483242551994078e-02, 3.634030104030812076e-03,
            -3.055888951116827318e-03, 6.608747470375698129e-04,
            8.237532257692081912e-05, -6.927015762150179410e-05,
            -6.099175331115514430e-03, 2.402310352789886402e-02,
            -3.861491558256636286e-02, -2.583867422346154685e-02,
            6.050621302336450097e-02, -9.822840263095998503e-02,
            -3.827994718203701213e-02, 1.252239810257823327e-01,
            -2.018867305507059950e-01, 1.136620144506474833e-01,
            2.747872876828840599e-02, -3.122582814578225147e-02,
            -2.136319389661417989e-01, -5.164728194785846160e-02,
            5.869009312256637939e-02, -3.147575788810638014e-02,
            -7.609523885036708832e-03, 8.647186232996251914e-03,
            -5.990706138603461330e-03, -7.467169124604876177e-04,
            6.279210400235934152e-04, -9.287887182821588476e-04,
            -1.157696985960763821e-04, 9.735179200124630735e-05,
            -2.966271471326579340e-02, -3.697335544996301071e-03,
            3.109123071928715683e-03, 1.800225987816693740e-01,
            -7.541487246259104271e-02, -1.362333179969384966e-01,
            -7.524185541795300192e-02, 3.152023672914239238e-02,
            5.693978247845072477e-02, 5.703636164117102669e-02,
            -2.389361095778780308e-02, -4.316265205277792366e-02,
            -4.915584336537091176e-02, -8.674240294138457763e-02,
            5.709724154860432860e-02, -8.679070528401405804e-02,
            -1.572017650485294793e-01, 1.034201569997979520e-01,
            -3.557746655862283752e-02, -8.626268394893003844e-02,
            5.645546718878535764e-02, 6.848075985139651621e-03,
            8.535845420570665554e-04, -7.177870012752625602e-04,
            8.266638576582277997e-04, 1.030402542123569647e-04,
            -8.664748649675494882e-05, 2.991751925173294011e-02,
            3.729095884068693231e-03, -3.135830629785046203e-03,
            1.523793442834292522e-02, -3.873020552543556677e-02,
            6.275576045602117292e-02, -3.842536616563556329e-02,
            1.249268983543572881e-01, -2.014296501045876875e-01,
            1.288704808602599873e-02, -6.326999354443738066e-02,
            1.014064886873057153e-01, -1.318711149757016143e-01,
            -3.188092889522457091e-02, 3.622832829002789468e-02,
            -3.210149046681261276e-02, -7.760799893075580151e-03,
            8.819090787585878374e-03, -2.047554776382226327e-01,
            -4.950132426418570042e-02, 5.625150484566552450e-02
        ])
        # Global references are obtained by summing the per-atom ones:
        # over the selected atoms for the output, over all atoms for the
        # nine-component quantity.
        self.expected_gt = self.expected_t.reshape(
            -1, self.nout).sum(0).reshape(-1)
        self.expected_gv = self.expected_v.reshape(
            1, self.nout, len(self.atype), 9).sum(-2).reshape(-1)

    def _assert_full_shapes(self, nframes, gt, ff, vv, at=None, av=None):
        """Check the shapes of ``eval_full`` outputs for ``nframes`` frames.

        ``at`` and ``av`` are the per-atom outputs that ``eval_full`` only
        returns with ``atomic=True``; they are skipped when ``None``.
        """
        natoms = len(self.atype)
        self.assertEqual(gt.shape, (nframes, self.nout))
        self.assertEqual(ff.shape, (nframes, self.nout, natoms, 3))
        self.assertEqual(vv.shape, (nframes, self.nout, 9))
        if at is not None:
            self.assertEqual(at.shape, (nframes, self.NSEL, self.nout))
        if av is not None:
            self.assertEqual(av.shape, (nframes, self.nout, natoms, 9))

    def _assert_full_values(self, nframes, gt, ff, vv, at=None, av=None):
        """Compare ``eval_full`` outputs with the references repeated per frame.

        ``at`` and ``av`` are skipped when ``None``.  Each comparison is
        labelled so that a failure names the quantity that differs.
        """
        checks = [("ff", ff, self.expected_f)]
        if at is not None:
            checks.append(("at", at, self.expected_t))
        if av is not None:
            checks.append(("av", av, self.expected_v))
        checks.append(("gt", gt, self.expected_gt))
        checks.append(("vv", vv, self.expected_gv))
        for label, actual, reference in checks:
            # All frames are copies of the same configuration, so the
            # single-frame reference is tiled once per frame.
            np.testing.assert_almost_equal(
                actual.reshape([-1]),
                np.tile(reference.reshape(-1), nframes),
                decimal=default_places,
                err_msg=label)

    def test_attrs(self):
        """The loaded model reports the expected metadata."""
        self.assertEqual(self.dp.get_ntypes(), 2)
        self.assertAlmostEqual(self.dp.get_rcut(), 4.0, places=default_places)
        self.assertEqual(self.dp.get_type_map(), ['O', 'H'])
        self.assertEqual(self.dp.get_sel_type(), [0])

    def test_1frame_old(self):
        """``eval`` with ``atomic=False`` returns the global output."""
        gt = self.dp.eval(self.coords, self.box, self.atype, atomic=False)
        self.assertEqual(gt.shape, (1, self.nout))
        np.testing.assert_almost_equal(gt.ravel(), self.expected_gt,
                                       default_places)

    def test_1frame_old_atm(self):
        """``eval`` with default arguments returns the per-atom output."""
        at = self.dp.eval(self.coords, self.box, self.atype)
        self.assertEqual(at.shape, (1, self.NSEL, self.nout))
        np.testing.assert_almost_equal(at.ravel(), self.expected_t,
                                       default_places)

    def test_2frame_old_atm(self):
        """``eval`` on two stacked copies gives the per-atom reference twice."""
        nframes = 2
        coords2 = np.concatenate((self.coords, self.coords))
        box2 = np.concatenate((self.box, self.box))
        at = self.dp.eval(coords2, box2, self.atype)
        self.assertEqual(at.shape, (nframes, self.NSEL, self.nout))
        expected_t2 = np.concatenate((self.expected_t, self.expected_t))
        np.testing.assert_almost_equal(at.ravel(), expected_t2,
                                       default_places)

    def test_1frame_full(self):
        """``eval_full`` with ``atomic=False`` returns three global arrays."""
        gt, ff, vv = self.dp.eval_full(self.coords,
                                       self.box,
                                       self.atype,
                                       atomic=False)
        self._assert_full_shapes(1, gt, ff, vv)
        self._assert_full_values(1, gt, ff, vv)

    def test_1frame_full_atm(self):
        """``eval_full`` with ``atomic=True`` also returns per-atom arrays."""
        gt, ff, vv, at, av = self.dp.eval_full(self.coords,
                                               self.box,
                                               self.atype,
                                               atomic=True)
        self._assert_full_shapes(1, gt, ff, vv, at, av)
        self._assert_full_values(1, gt, ff, vv, at, av)

    def test_1frame_full_atm_shuffle(self):
        """Reordering the atoms only reorders the per-atom results."""
        # Permutation applied to the atoms, and the matching permutation of
        # the selected atoms in the per-atom output.
        i_sf = [2, 1, 3, 0, 5, 4]
        isel_sf = [1, 0]
        shuffled_coords = self.coords.reshape(-1, 3)[i_sf].reshape(-1)
        gt, ff, vv, at, av = self.dp.eval_full(shuffled_coords,
                                               self.box,
                                               self.atype[i_sf],
                                               atomic=True)
        self._assert_full_shapes(1, gt, ff, vv, at, av)
        # Scatter the shuffled per-atom results back to the original atom
        # order so they can be compared with the unshuffled references.
        nff = np.empty_like(ff)
        nav = np.empty_like(av)
        nat = np.empty_like(at)
        nff[:, :, i_sf] = ff
        nav[:, :, i_sf] = av
        nat[:, isel_sf] = at
        self._assert_full_values(1, gt, nff, vv, nat, nav)

    def test_1frame_num_deriv(self):
        """Numerical derivatives of ``eval`` agree with the stored references.

        ``finite_difference`` and ``strerch_box`` are helpers defined outside
        this class.
        """
        # numerical force: negated derivative of the global output with
        # respect to the coordinates
        num_f = -finite_difference(
            lambda coord: self.dp.eval(
                coord, self.box, self.atype, atomic=False).reshape(-1),
            self.coords).reshape(-1)
        np.testing.assert_allclose(num_f.reshape([-1]),
                                   self.expected_f.reshape([-1]),
                                   atol=1e-5)
        # numerical virial: negated derivative with respect to the box
        # (coordinates are passed through strerch_box for each trial box),
        # transposed and multiplied by the box matrix
        num_v = -(finite_difference(
            lambda box: self.dp.eval(strerch_box(self.coords, self.box, box),
                                     box,
                                     self.atype,
                                     atomic=False).reshape(-1),
            self.box).reshape(-1, 3, 3).transpose(0, 2, 1) @ self.box.reshape(
                3, 3)).reshape(-1)
        np.testing.assert_allclose(num_v.reshape([-1]),
                                   self.expected_gv.reshape([-1]),
                                   atol=1e-5)

    def test_2frame_full_atm(self):
        """``eval_full`` on two stacked copies gives every reference twice."""
        nframes = 2
        coords2 = np.concatenate((self.coords, self.coords))
        box2 = np.concatenate((self.box, self.box))
        gt, ff, vv, at, av = self.dp.eval_full(coords2,
                                               box2,
                                               self.atype,
                                               atomic=True)
        self._assert_full_shapes(nframes, gt, ff, vv, at, av)
        self._assert_full_values(nframes, gt, ff, vv, at, av)