Exemplo n.º 1
0
 def test_ener_shift(self):
     """Check EnerFitting's output stats against the data system's energy shift.

     The same seeded data system is used twice:
     * directly via ``DeepmdDataSystem.compute_energy_shift``;
     * through ``make_all_stat`` followed by ``EnerFitting._compute_output_stats``.

     Both paths must return sequences of equal length that agree element-wise.
     """
     np.random.seed(0)
     data = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
     data.add('energy', 1, must=True)
     ener_shift0 = data.compute_energy_shift(rcond=1)
     all_stat = make_all_stat(data, 4, merge_sys=False)
     ener_shift1 = EnerFitting._compute_output_stats(all_stat, rcond=1)
     # Check lengths explicitly. An index loop over ener_shift0 alone would
     # raise IndexError if ener_shift1 were shorter, and would silently
     # ignore trailing entries if it were longer.
     self.assertEqual(len(ener_shift0), len(ener_shift1))
     for expected, actual in zip(ener_shift0, ener_shift1):
         self.assertAlmostEqual(expected, actual)
Exemplo n.º 2
0
 def test_ntypes(self):
     """Check the type count, per-system batch counts and batch sizes."""
     bsize, tsize = 3, 2
     system = DeepmdDataSystem(self.sys_name, bsize, tsize, 2.0)
     # 'test' is a required item; 'null' is optional (must=False).
     for key, required in (('test', True), ('null', False)):
         system.add(key, self.test_ndof, atomic=True, must=required)
     expected_batch_sizes = [bsize] * 4
     self.assertEqual(system.get_ntypes(), 3)
     self.assertEqual(system.get_nbatches(), [2, 4, 3, 2])
     self.assertEqual(system.get_nsystems(), self.nsys)
     self.assertEqual(list(system.get_batch_size()), expected_batch_sizes)
Exemplo n.º 3
0
    def test_get_test(self):
        """Check ``get_test`` output for systems 0 and 2.

        For each checked system:
        * the first row of ``data['type']`` equals the sorted atom types;
        * the ``coord`` and ``test`` arrays stored under ``sys_<idx>/set.002``
          are found in the returned data (via ``_in_array`` and the system's
          ``idx_map``);
        * the optional ``null`` item is all zeros.
        """
        batch_size = 3
        test_size = 2
        ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
        ds.add('test', self.test_ndof, atomic=True, must=True)
        # 'null' is optional (must=False); the checks below expect zeros for it.
        ds.add('null', self.test_ndof, atomic=True, must=False)
        # One loop instead of two copy-pasted stanzas keeps the per-system
        # checks identical; subTest reports which system failed.
        for sys_idx in (0, 2):
            with self.subTest(sys_idx=sys_idx):
                data = ds.get_test(sys_idx=sys_idx)
                self.assertEqual(list(data['type'][0]),
                                 list(np.sort(self.atom_type[sys_idx])))
                idx_map = ds.get_sys(sys_idx).idx_map
                set_dir = 'sys_{}/set.002'.format(sys_idx)
                self._in_array(np.load(set_dir + '/coord.npy'),
                               idx_map, 3, data['coord'])
                self._in_array(np.load(set_dir + '/test.npy'),
                               idx_map, self.test_ndof, data['test'])
                expected_null = np.zeros([
                    self.nframes[sys_idx] + 2,
                    self.natoms[sys_idx] * self.test_ndof,
                ])
                self.assertAlmostEqual(
                    np.linalg.norm(expected_null - data['null']), 0.0)
Exemplo n.º 4
0
    def test_merge_all_stat(self):
        """Check that three ways of building merged statistics agree on 'energy'.

        The three results compared are:
        * ``make_all_stat(..., merge_sys=True)`` (``d1``);
        * ``merge_sys_stat`` applied to ``make_all_stat(..., merge_sys=False)``
          (``d2``);
        * the reference implementation ``_make_all_stat_ref`` (``d3``).

        ``d2`` and ``d3`` are each compared against ``d1`` with ``_comp_data``.
        """
        # Re-seed before each construction / stat pass, presumably so every
        # path draws the same random samples.
        np.random.seed(0)
        data0 = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
        data0.add('energy', 1, must=True)
        np.random.seed(0)
        data1 = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
        data1.add('force', 3, atomic=True, must=True)
        np.random.seed(0)
        data2 = DeepmdDataSystem(['system_0', 'system_1'], 5, 10, 1.0)
        data2.add('force', 3, atomic=True, must=True)

        np.random.seed(0)
        all_stat_0 = make_all_stat(data0, 10, merge_sys=False)
        np.random.seed(0)
        all_stat_1 = make_all_stat(data1, 10, merge_sys=True)
        all_stat_2 = merge_sys_stat(all_stat_0)
        np.random.seed(0)
        all_stat_3 = _make_all_stat_ref(data2, 10)

        # Only check that the energy is concatenated correctly.
        # NOTE(review): data1/data2 only add 'force', yet 'energy' is read from
        # their stats. Confirm the data system always provides energy.
        dd = 'energy'
        d1 = np.array(all_stat_1[dd])
        d2 = np.array(all_stat_2[dd])
        d3 = np.array(all_stat_3[dd])
        self._comp_data(d1, d2)
        self._comp_data(d1, d3)