示例#1
0
 def __init__(self, jdata, descrpt):
     """Set up a polarizability fitting net on top of a ``DescrptSeA`` descriptor.

     Args:
         jdata: dict of fitting-net options, parsed with ``ClassArg``.
         descrpt: descriptor object; must be an instance of ``DescrptSeA``.

     Raises:
         RuntimeError: if ``descrpt`` is not a ``DescrptSeA``.
     """
     if not isinstance(descrpt, DescrptSeA):
         raise RuntimeError('PolarFittingSeA only supports DescrptSeA')
     self.ntypes = descrpt.get_ntypes()
     self.dim_descrpt = descrpt.get_dim_out()

     def as_list(value):
         # These options accept a bare scalar as well as a list; a scalar is
         # wrapped in a one-element list so later code can always index it.
         return value if type(value) is list else [value]

     parser = ClassArg()
     parser.add('neuron', list, default=[120, 120, 120], alias='n_neuron')
     parser.add('resnet_dt', bool, default=True)
     parser.add('fit_diag', bool, default=True)
     parser.add('diag_shift', [list, float], default=[0.0] * self.ntypes)
     parser.add('scale', [list, float], default=[1.0] * self.ntypes)
     parser.add('sel_type', [list, int], default=list(range(self.ntypes)), alias='pol_type')
     parser.add('seed', int)
     parser.add("activation_function", str, default="tanh")
     parser.add('precision', str, default="default")
     opts = parser.parse(jdata)

     self.n_neuron = opts['neuron']
     self.resnet_dt = opts['resnet_dt']
     self.fit_diag = opts['fit_diag']
     self.seed = opts['seed']
     self.fitting_activation_fn = get_activation_func(opts["activation_function"])
     self.fitting_precision = get_precision(opts['precision'])
     self.sel_type = as_list(opts['sel_type'])
     self.diag_shift = as_list(opts['diag_shift'])
     self.scale = as_list(opts['scale'])

     # dim_rot_mat is three times the descriptor's rot-mat-1 dimension.
     self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
     self.dim_rot_mat = self.dim_rot_mat_1 * 3
     self.useBN = False
示例#2
0
文件: wfc.py 项目: y1xiaoc/deepmd-kit
 def __init__ (self, jdata, descrpt):
     """Set up a WFC fitting net; the descriptor must be a ``DescrptLocFrame``.

     Args:
         jdata: dict of fitting-net options; ``wfc_numb`` is mandatory.
         descrpt: descriptor object; must be an instance of ``DescrptLocFrame``.

     Raises:
         RuntimeError: if ``descrpt`` is not a ``DescrptLocFrame``.
     """
     if not isinstance(descrpt, DescrptLocFrame):
         raise RuntimeError('WFC only supports DescrptLocFrame')
     self.ntypes = descrpt.get_ntypes()
     self.dim_descrpt = descrpt.get_dim_out()

     parser = ClassArg()
     parser.add('neuron', list, default=[120, 120, 120], alias='n_neuron')
     parser.add('resnet_dt', bool, default=True)
     parser.add('wfc_numb', int, must=True)
     parser.add('sel_type', [list, int], default=list(range(self.ntypes)), alias='wfc_type')
     parser.add('seed', int)
     parser.add("activation_function", str, default="tanh")
     parser.add('precision', str, default="default")
     parser.add('uniform_seed', bool, default=False)
     opts = parser.parse(jdata)

     self.n_neuron, self.resnet_dt, self.wfc_numb = (
         opts['neuron'], opts['resnet_dt'], opts['wfc_numb'])
     self.sel_type, self.seed, self.uniform_seed = (
         opts['sel_type'], opts['seed'], opts['uniform_seed'])
     self.seed_shift = one_layer_rand_seed_shift()
     self.fitting_activation_fn = get_activation_func(opts["activation_function"])
     self.fitting_precision = get_precision(opts['precision'])
     self.useBN = False
示例#3
0
 def __init__(self, jdata, descrpt):
     """Set up an energy fitting net on top of ``descrpt``.

     Args:
         jdata: dict of fitting-net options, parsed with ``ClassArg``.
         descrpt: descriptor object providing ``get_ntypes``/``get_dim_out``.

     Raises:
         AssertionError: if ``trainable`` is a list whose length is not
             ``len(neuron) + 1``.
     """
     # model param
     self.ntypes = descrpt.get_ntypes()
     self.dim_descrpt = descrpt.get_dim_out()

     parser = ClassArg()
     parser.add('numb_fparam', int, default=0)
     parser.add('numb_aparam', int, default=0)
     parser.add('neuron', list, default=[120, 120, 120], alias='n_neuron')
     parser.add('resnet_dt', bool, default=True)
     parser.add('rcond', float, default=1e-3)
     parser.add('seed', int)
     parser.add('atom_ener', list, default=[])
     parser.add("activation_function", str, default="tanh")
     parser.add("precision", str, default="default")
     parser.add("trainable", [list, bool], default=True)
     opts = parser.parse(jdata)

     self.numb_fparam = opts['numb_fparam']
     self.numb_aparam = opts['numb_aparam']
     self.n_neuron = opts['neuron']
     self.resnet_dt = opts['resnet_dt']
     self.rcond = opts['rcond']
     self.seed = opts['seed']
     self.fitting_activation_fn = get_activation_func(opts["activation_function"])
     self.fitting_precision = get_precision(opts['precision'])

     # One trainable flag per entry of n_neuron plus one more (presumably the
     # final layer); a single bool is broadcast to all of them.
     n_flags = len(self.n_neuron) + 1
     trainable = opts['trainable']
     self.trainable = [trainable] * n_flags if type(trainable) is bool else trainable
     assert len(self.trainable) == n_flags, \
         'length of trainable should be that of n_neuron + 1'

     # Per-type atomic energies become named TF constants; None entries stay None.
     self.atom_ener = [
         None if ae is None else tf.constant(ae,
                                             global_tf_float_precision,
                                             name="atom_%d_ener" % at)
         for at, ae in enumerate(opts['atom_ener'])
     ]
     self.useBN = False
     self.bias_atom_e = None

     # data requirement: frame/atomic parameters are registered only when used;
     # their statistics start out unset.
     if self.numb_fparam > 0:
         add_data_requirement('fparam', self.numb_fparam,
                              atomic=False, must=True, high_prec=False)
         self.fparam_avg = self.fparam_std = self.fparam_inv_std = None
     if self.numb_aparam > 0:
         add_data_requirement('aparam', self.numb_aparam,
                              atomic=True, must=True, high_prec=False)
         self.aparam_avg = self.aparam_std = self.aparam_inv_std = None
示例#4
0
 def test_add_multi(self):
     # Keys registered through a chained add() are all parsed.
     parser = ClassArg().add('test', int).add('test1', str)
     parser.parse({'test': 10, 'test1': 'foo'})
     expected = {'test': 10, 'test1': 'foo'}
     self.assertEqual(parser.get_dict(), expected)
示例#5
0
    def __init__(self, jdata):
        """Parse descriptor options from ``jdata`` and derive its sizes.

        Args:
            jdata: dict of descriptor options; ``sel`` (per-type neighbor
                counts) is required, everything else has a default.
        """
        parser = ClassArg()
        parser.add('sel', list, must=True)
        parser.add('rcut', float, default=6.0)
        parser.add('rcut_smth', float, default=5.5)
        parser.add('neuron', list, default=[10, 20, 40])
        parser.add('axis_neuron', int, default=4, alias='n_axis_neuron')
        parser.add('resnet_dt', bool, default=False)
        parser.add('trainable', bool, default=True)
        parser.add('seed', int)
        opts = parser.parse(jdata)

        self.sel_a = opts['sel']
        self.rcut_r = opts['rcut']
        self.rcut_r_smth = opts['rcut_smth']
        self.filter_neuron = opts['neuron']
        self.n_axis_neuron = opts['axis_neuron']
        self.filter_resnet_dt = opts['resnet_dt']
        self.seed = opts['seed']
        self.trainable = opts['trainable']

        # descrpt config: one sel entry per type; the sel_r selection is all zeros.
        self.ntypes = len(self.sel_a)
        self.sel_r = [0] * self.ntypes
        assert self.ntypes == len(self.sel_r)
        self.rcut_a = -1
        # Totals over all types (cumsum(...)[-1] raises IndexError on an empty sel).
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        # 4 descriptor entries per sel_a neighbor, 1 per sel_r neighbor.
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r

        self.useBN = False
示例#6
0
 def test_add_multi_types(self):
     # With types [str, list], a list value is kept and an int is cast to str.
     parser = ClassArg()
     parser.add('test', [str, list])
     parser.add('test1', [str, list])
     parser.parse({'test': [10, 20], 'test1': 10})
     self.assertEqual(parser.get_dict(),
                      {'test': [10, 20], 'test1': '10'})
示例#7
0
 def test_add_default_overwrite(self):
     # A value supplied through an alias overrides the declared default.
     parser = ClassArg()
     parser.add('test', str, alias=['test1', 'test2'], default='bar')
     parser.parse({'test2': 'foo'})
     self.assertEqual(parser.get_dict(), {'test': 'foo'})
示例#8
0
 def test_multi_add(self):
     # Keys can be registered after a first parse; re-parsing picks them up.
     source = {'test2': 'foo'}
     parser = ClassArg().add('test', int)
     parser.parse(source)
     self.assertEqual(parser.get_dict(), {'test': None})
     parser.add('test2', str)
     parser.parse(source)
     self.assertEqual(parser.get_dict(), {'test': None, 'test2': 'foo'})
示例#9
0
 def __init__(self, jdata):
     """Parse learning-rate options from ``jdata``.

     ``start_lr`` is required; ``decay_steps``, ``decay_rate`` and
     ``stop_lr`` are optional. The parsed dict is kept in ``self.cd``.
     """
     parser = ClassArg()
     for key, kind, required in (('decay_steps', int, False),
                                 ('decay_rate', float, False),
                                 ('start_lr', float, True),
                                 ('stop_lr', float, False)):
         parser.add(key, kind, must=required)
     self.cd = parser.parse(jdata)
     self.start_lr_ = self.cd['start_lr']
示例#10
0
 def __init__(self, jdata, **kwarg):
     """Configure the energy/force/virial/atomic-energy/atom-pref loss prefactors.

     Args:
         jdata: dict of loss options; every prefactor is optional.
         **kwarg: must contain ``starter_learning_rate``.

     Raises:
         KeyError: if ``starter_learning_rate`` is not passed.
     """
     self.starter_learning_rate = kwarg['starter_learning_rate']

     # (loss term, default start prefactor, default limit prefactor)
     pref_defaults = (
         ('e', 0.02, 1.00),
         ('f', 1000, 1.00),
         ('v', 0, 0),
         ('ae', 0, 0),
         ('pf', 0, 0),
     )
     parser = ClassArg()
     for term, start_default, limit_default in pref_defaults:
         parser.add('start_pref_' + term, float, default=start_default)
         parser.add('limit_pref_' + term, float, default=limit_default)
     parser.add('relative_f', float)
     opts = parser.parse(jdata)

     self.start_pref_e, self.limit_pref_e = opts['start_pref_e'], opts['limit_pref_e']
     self.start_pref_f, self.limit_pref_f = opts['start_pref_f'], opts['limit_pref_f']
     self.start_pref_v, self.limit_pref_v = opts['start_pref_v'], opts['limit_pref_v']
     self.start_pref_ae, self.limit_pref_ae = opts['start_pref_ae'], opts['limit_pref_ae']
     self.start_pref_pf, self.limit_pref_pf = opts['start_pref_pf'], opts['limit_pref_pf']
     self.relative_f = opts['relative_f']

     def active(start, limit):
         # A term is switched on when either of its prefactors is non-zero.
         return start != 0 or limit != 0

     self.has_e = active(self.start_pref_e, self.limit_pref_e)
     self.has_f = active(self.start_pref_f, self.limit_pref_f)
     self.has_v = active(self.start_pref_v, self.limit_pref_v)
     self.has_ae = active(self.start_pref_ae, self.limit_pref_ae)
     self.has_pf = active(self.start_pref_pf, self.limit_pref_pf)

     # data required: all labels are optional; (name, ndof, atomic, high_prec)
     for name, ndof, per_atom, high_prec in (('energy', 1, False, True),
                                             ('force', 3, True, False),
                                             ('virial', 9, False, False),
                                             ('atom_ener', 1, True, False)):
         add_data_requirement(name, ndof, atomic=per_atom, must=False, high_prec=high_prec)
     add_data_requirement('atom_pref', 1, atomic=True, must=False, high_prec=False, repeat=3)
示例#11
0
 def __init__(self, jdata):
     """Combine a ``DescrptSeA`` and a ``DescrptSeR`` built from two sub-dicts.

     Args:
         jdata: dict with the required sub-dicts ``a`` (passed to
             ``DescrptSeA``) and ``r`` (passed to ``DescrptSeR``).

     Raises:
         AssertionError: if the two descriptors disagree on the number of types.
     """
     opts = ClassArg() \
         .add('a', dict, must=True) \
         .add('r', dict, must=True) \
         .parse(jdata)
     self.param_a, self.param_r = opts['a'], opts['r']
     self.descrpt_a = DescrptSeA(self.param_a)
     self.descrpt_r = DescrptSeR(self.param_r)
     # Both halves must describe the same set of atom types.
     assert self.descrpt_a.get_ntypes() == self.descrpt_r.get_ntypes()
示例#12
0
    def __init__(self, jdata, descrpt, fitting):
        """Bind a descriptor and a fitting net into a model.

        Args:
            jdata: dict of model options; only ``type_map`` is read here.
            descrpt: descriptor object providing ``get_rcut``/``get_ntypes``.
            fitting: fitting-net object.
        """
        self.descrpt = descrpt
        self.rcut = descrpt.get_rcut()
        self.ntypes = descrpt.get_ntypes()
        self.fitting = fitting

        opts = ClassArg().add('type_map', list, default=[]).parse(jdata)
        self.type_map = opts['type_map']
示例#13
0
    def __init__(self, jdata):
        """Parse local-frame descriptor options and build a helper TF sub-graph.

        Reads the required lists ``sel_a``, ``sel_r`` and ``axis_rule`` and the
        float ``rcut`` (default 6.0) from ``jdata`` via ``ClassArg``. It then
        derives the neighbor and descriptor counts and builds a private
        ``tf.Graph`` that feeds placeholders into ``op_module.descrpt`` with an
        all-zeros average and all-ones std. A ``tf.Session`` bound to that graph
        is kept in ``self.sub_sess``.

        Args:
            jdata: dict of descriptor options.

        Raises:
            AssertionError: if ``sel_a`` and ``sel_r`` have different lengths.
        """
        args = ClassArg()\
               .add('sel_a',    list,   must = True) \
               .add('sel_r',    list,   must = True) \
               .add('rcut',     float,  default = 6.0) \
               .add('axis_rule',list,   must = True)
        class_data = args.parse(jdata)
        self.sel_a = class_data['sel_a']
        self.sel_r = class_data['sel_r']
        self.axis_rule = class_data['axis_rule']
        self.rcut_r = class_data['rcut']
        # ntypes is the length of sel_a (sel_r must match); rcut_a is fixed at -1
        self.ntypes = len(self.sel_a)
        assert(self.ntypes == len(self.sel_r))
        self.rcut_a = -1
        # numb of neighbors and numb of descrptors
        # (cumsum(...)[-1] is the total over all types; it raises IndexError on an empty list)
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        # 4 descriptor entries per sel_a neighbor, 1 per sel_r neighbor
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
        # Statistics are left unset here (presumably filled from data elsewhere — confirm).
        self.davg = None
        self.dstd = None

        self.place_holders = {}
        # All-zeros avg and all-ones std passed to the op below, presumably so
        # its output is not normalized — confirm against op_module.descrpt.
        avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
        std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
        # The ops below live in their own graph, not the default graph.
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_lf_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
            self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
            # Only the descriptor output is kept as an attribute; the other op outputs are discarded.
            self.stat_descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
                = op_module.descrpt (self.place_holders['coord'],
                                     self.place_holders['type'],
                                     self.place_holders['natoms_vec'],
                                     self.place_holders['box'],
                                     self.place_holders['default_mesh'],
                                     tf.constant(avg_zero),
                                     tf.constant(std_ones),
                                     rcut_a = self.rcut_a,
                                     rcut_r = self.rcut_r,
                                     sel_a = self.sel_a,
                                     sel_r = self.sel_r,
                                     axis_rule = self.axis_rule)
        self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
示例#14
0
    def __init__ (self, jdata, descrpt, fitting, var_name):
        """Bind a descriptor and a fitting net into a model tagged ``var_name``.

        Args:
            jdata: dict of model options (``type_map``, ``data_stat_nbatch``).
            descrpt: descriptor object providing ``get_rcut``/``get_ntypes``.
            fitting: fitting-net object.
            var_name: stored as ``self.model_type``.
        """
        self.model_type = var_name
        self.descrpt = descrpt
        self.rcut = descrpt.get_rcut()
        self.ntypes = descrpt.get_ntypes()
        self.fitting = fitting

        parser = ClassArg()
        parser.add('type_map', list, default=[])
        parser.add('data_stat_nbatch', int, default=10)
        opts = parser.parse(jdata)
        self.type_map = opts['type_map']
        self.data_stat_nbatch = opts['data_stat_nbatch']
示例#15
0
 def __init__(self, jdata, descrpt):
     """Set up a polarizability fitting net on a ``DescrptLocFrame`` descriptor.

     Args:
         jdata: dict of fitting-net options, parsed with ``ClassArg``.
         descrpt: descriptor object; must be an instance of ``DescrptLocFrame``.

     Raises:
         RuntimeError: if ``descrpt`` is not a ``DescrptLocFrame``.
     """
     if not isinstance(descrpt, DescrptLocFrame):
         raise RuntimeError(
             'PolarFittingLocFrame only supports DescrptLocFrame')
     self.ntypes = descrpt.get_ntypes()
     self.dim_descrpt = descrpt.get_dim_out()

     parser = ClassArg()
     parser.add('neuron', list, default=[120, 120, 120], alias='n_neuron')
     parser.add('resnet_dt', bool, default=True)
     parser.add('sel_type', [list, int], default=list(range(self.ntypes)), alias='pol_type')
     parser.add('seed', int)
     opts = parser.parse(jdata)

     self.n_neuron, self.resnet_dt = opts['neuron'], opts['resnet_dt']
     self.sel_type, self.seed = opts['sel_type'], opts['seed']
     self.useBN = False
示例#16
0
 def __init__ (self,
               starter_learning_rate : float,
               start_pref_e : float = 0.1,
               limit_pref_e : float = 1.0,
               start_pref_ed : float = 1.0,
               limit_pref_ed : float = 1.0
 ) -> None :
     """Set up the energy + energy-dipole loss prefactors.

     Registers ``energy`` (1 component, high precision) and ``energy_dipole``
     (3 components) as required, non-atomic data labels.

     Args:
         starter_learning_rate: initial learning rate, stored on the instance.
         start_pref_e: start prefactor of the energy term.
         limit_pref_e: limit prefactor of the energy term.
         start_pref_ed: start prefactor of the energy-dipole term.
         limit_pref_ed: limit prefactor of the energy-dipole term.
     """
     # The prefactors come straight from the declared parameters; this
     # signature has no ``jdata``/``**kwarg`` to parse.
     self.starter_learning_rate = starter_learning_rate
     self.start_pref_e = start_pref_e
     self.limit_pref_e = limit_pref_e
     self.start_pref_ed = start_pref_ed
     self.limit_pref_ed = limit_pref_ed
     # data required
     add_data_requirement('energy', 1, atomic=False, must=True, high_prec=True)
     add_data_requirement('energy_dipole', 3, atomic=False, must=True, high_prec=False)
示例#17
0
    def __init__(self, jdata, descrpt, fitting):
        """Bind descriptor and fitting net, optionally loading a ``TabInter`` table.

        Args:
            jdata: dict of model options: ``type_map``, ``rcond`` and
                ``use_srtab``. When ``use_srtab`` is given, ``smin_alpha``,
                ``sw_rmin`` and ``sw_rmax`` are also required.
            descrpt: descriptor object providing ``get_rcut``/``get_ntypes``.
            fitting: fitting-net object providing ``get_numb_fparam``.
        """
        self.descrpt = descrpt
        self.rcut = descrpt.get_rcut()
        self.ntypes = descrpt.get_ntypes()
        self.fitting = fitting
        self.numb_fparam = fitting.get_numb_fparam()

        parser = ClassArg()
        parser.add('type_map', list, default=[])
        parser.add('rcond', float, default=1e-3)
        parser.add('use_srtab', str)
        opts = parser.parse(jdata)
        self.type_map = opts['type_map']
        self.srtab_name = opts['use_srtab']
        self.rcond = opts['rcond']

        # Without a table file there is nothing more to configure.
        if self.srtab_name is None:
            self.srtab = None
            return

        self.srtab = TabInter(self.srtab_name)
        # These three options only become mandatory once a table is used, so
        # they are registered on the same parser and the input is re-parsed.
        parser.add('smin_alpha', float, must=True)
        parser.add('sw_rmin', float, must=True)
        parser.add('sw_rmax', float, must=True)
        opts = parser.parse(jdata)
        self.smin_alpha = opts['smin_alpha']
        self.sw_rmin = opts['sw_rmin']
        self.sw_rmax = opts['sw_rmax']
示例#18
0
 def __init__(self, jdata, descrpt):
     """Set up a fitting net with optional frame parameters.

     Args:
         jdata: dict of fitting-net options, parsed with ``ClassArg``.
         descrpt: descriptor object providing ``get_ntypes``/``get_dim_out``.
     """
     # model param
     self.ntypes = descrpt.get_ntypes()
     self.dim_descrpt = descrpt.get_dim_out()

     parser = ClassArg()
     parser.add('numb_fparam', int, default=0)
     parser.add('neuron', list, default=[120, 120, 120], alias='n_neuron')
     parser.add('resnet_dt', bool, default=True)
     parser.add('seed', int)
     opts = parser.parse(jdata)

     self.numb_fparam = opts['numb_fparam']
     self.n_neuron = opts['neuron']
     self.resnet_dt = opts['resnet_dt']
     self.seed = opts['seed']
     self.useBN = False

     # Frame parameters are registered as (optional) data only when requested.
     if not self.numb_fparam > 0:
         return
     add_data_requirement('fparam', self.numb_fparam,
                          atomic=False, must=False, high_prec=False)
示例#19
0
 def __init__(self, jdata):
     """Parse local-frame descriptor options and derive its sizes.

     Args:
         jdata: dict with required lists ``sel_a``, ``sel_r``, ``axis_rule``
             and an optional float ``rcut`` (default 6.0).

     Raises:
         AssertionError: if ``sel_a`` and ``sel_r`` have different lengths.
     """
     opts = ClassArg() \
         .add('sel_a', list, must=True) \
         .add('sel_r', list, must=True) \
         .add('rcut', float, default=6.0) \
         .add('axis_rule', list, must=True) \
         .parse(jdata)
     self.sel_a, self.sel_r = opts['sel_a'], opts['sel_r']
     self.axis_rule = opts['axis_rule']
     self.rcut_r = opts['rcut']

     # One sel_a and one sel_r entry per atom type; rcut_a is fixed at -1.
     self.ntypes = len(self.sel_a)
     assert self.ntypes == len(self.sel_r)
     self.rcut_a = -1

     # Totals over all types (cumsum(...)[-1] raises IndexError on an empty list).
     self.nnei_a = np.cumsum(self.sel_a)[-1]
     self.nnei_r = np.cumsum(self.sel_r)[-1]
     self.nnei = self.nnei_a + self.nnei_r
     # 4 descriptor entries per sel_a neighbor, 1 per sel_r neighbor.
     self.ndescrpt_a = self.nnei_a * 4
     self.ndescrpt_r = self.nnei_r * 1
     self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
示例#20
0
    def __init__(self, jdata):
        """Parse descriptor options and build a helper TF sub-graph.

        The options read from ``jdata`` via ``ClassArg`` are:
        ``sel`` (required list), ``rcut``, ``rcut_smth``, ``neuron``,
        ``axis_neuron`` (alias ``n_axis_neuron``), ``resnet_dt``,
        ``trainable``, ``seed``, ``exclude_types``, ``set_davg_zero``,
        ``activation_function`` and ``precision``.

        The method then derives the neighbor and descriptor counts and builds
        a private ``tf.Graph`` that feeds placeholders into
        ``op_module.descrpt_se_a`` with an all-zeros average and all-ones std.
        A ``tf.Session`` bound to that graph is kept in ``self.sub_sess``.

        Args:
            jdata: dict of descriptor options.

        Raises:
            AssertionError: if an ``exclude_types`` entry does not have exactly
                two elements.
        """
        args = ClassArg()\
               .add('sel',      list,   must = True) \
               .add('rcut',     float,  default = 6.0) \
               .add('rcut_smth',float,  default = 5.5) \
               .add('neuron',   list,   default = [10, 20, 40]) \
               .add('axis_neuron', int, default = 4, alias = 'n_axis_neuron') \
               .add('resnet_dt',bool,   default = False) \
               .add('trainable',bool,   default = True) \
               .add('seed',     int) \
               .add('exclude_types', list, default = []) \
               .add('set_davg_zero', bool, default = False) \
               .add('activation_function', str,    default = 'tanh') \
               .add('precision', str, default = "default")
        class_data = args.parse(jdata)
        self.sel_a = class_data['sel']
        self.rcut_r = class_data['rcut']
        self.rcut_r_smth = class_data['rcut_smth']
        self.filter_neuron = class_data['neuron']
        self.n_axis_neuron = class_data['axis_neuron']
        self.filter_resnet_dt = class_data['resnet_dt']
        self.seed = class_data['seed']
        self.trainable = class_data['trainable']
        self.filter_activation_fn = get_activation_func(
            class_data['activation_function'])
        self.filter_precision = get_precision(class_data['precision'])
        exclude_types = class_data['exclude_types']
        # Each excluded type pair is stored in both orders, so a membership
        # check on (ti, tj) gives the same answer as on (tj, ti).
        self.exclude_types = set()
        for tt in exclude_types:
            assert (len(tt) == 2)
            self.exclude_types.add((tt[0], tt[1]))
            self.exclude_types.add((tt[1], tt[0]))
        self.set_davg_zero = class_data['set_davg_zero']

        # descrpt config
        # One sel entry per type; the sel_r selection is all zeros.
        self.sel_r = [0 for ii in range(len(self.sel_a))]
        self.ntypes = len(self.sel_a)
        assert (self.ntypes == len(self.sel_r))
        self.rcut_a = -1
        # numb of neighbors and numb of descrptors
        # (cumsum(...)[-1] is the total over all types; it raises IndexError on an empty sel)
        self.nnei_a = np.cumsum(self.sel_a)[-1]
        self.nnei_r = np.cumsum(self.sel_r)[-1]
        self.nnei = self.nnei_a + self.nnei_r
        # 4 descriptor entries per sel_a neighbor, 1 per sel_r neighbor
        self.ndescrpt_a = self.nnei_a * 4
        self.ndescrpt_r = self.nnei_r * 1
        self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
        self.useBN = False
        # Statistics are left unset here (presumably filled from data elsewhere — confirm).
        self.dstd = None
        self.davg = None

        self.place_holders = {}
        # All-zeros avg and all-ones std passed to the op below, presumably so
        # its output is not normalized — confirm against op_module.descrpt_se_a.
        avg_zero = np.zeros([self.ntypes,
                             self.ndescrpt]).astype(global_np_float_precision)
        std_ones = np.ones([self.ntypes,
                            self.ndescrpt]).astype(global_np_float_precision)
        # The ops below live in their own graph, not the default graph.
        sub_graph = tf.Graph()
        with sub_graph.as_default():
            name_pfx = 'd_sea_'
            for ii in ['coord', 'box']:
                self.place_holders[ii] = tf.placeholder(
                    global_np_float_precision, [None, None],
                    name=name_pfx + 't_' + ii)
            self.place_holders['type'] = tf.placeholder(tf.int32, [None, None],
                                                        name=name_pfx +
                                                        't_type')
            self.place_holders['natoms_vec'] = tf.placeholder(
                tf.int32, [self.ntypes + 2], name=name_pfx + 't_natoms')
            self.place_holders['default_mesh'] = tf.placeholder(
                tf.int32, [None], name=name_pfx + 't_mesh')
            # Only the descriptor output is kept as an attribute; the other op outputs are discarded.
            self.stat_descrpt, descrpt_deriv, rij, nlist \
                = op_module.descrpt_se_a(self.place_holders['coord'],
                                         self.place_holders['type'],
                                         self.place_holders['natoms_vec'],
                                         self.place_holders['box'],
                                         self.place_holders['default_mesh'],
                                         tf.constant(avg_zero),
                                         tf.constant(std_ones),
                                         rcut_a = self.rcut_a,
                                         rcut_r = self.rcut_r,
                                         rcut_r_smth = self.rcut_r_smth,
                                         sel_a = self.sel_a,
                                         sel_r = self.sel_r)
        self.sub_sess = tf.Session(graph=sub_graph,
                                   config=default_tf_session_config)
示例#21
0
 def test_add_none(self):
     # An optional key with no default that is absent from the input reads as None.
     parser = ClassArg().add('test', int)
     parser.parse({'test2': 'foo'})
     self.assertEqual(parser.get_dict(), {'test': None})
示例#22
0
 def test_add_must(self):
     # A required key missing from the input makes parse() raise RuntimeError.
     parser = ClassArg().add('test', str, must=True)
     with self.assertRaises(RuntimeError):
         parser.parse({'test2': 'foo'})
示例#23
0
 def test_add_alias(self):
     # A value given under an alias is stored under the canonical key.
     parser = ClassArg().add('test', str, alias=['test1', 'test2'])
     parser.parse({'test2': 'foo'})
     self.assertEqual(parser.get_dict(), {'test': 'foo'})
示例#24
0
 def test_add_wrong_type_cvt(self):
     # An int cannot be converted to the declared list type: parse() raises TypeError.
     parser = ClassArg().add('test', list)
     with self.assertRaises(TypeError):
         parser.parse({'test': 10})
示例#25
0
 def test_add_type_cvt(self):
     # A numeric string is converted to the declared float type.
     parser = ClassArg().add('test', float)
     parser.parse({'test': '10'})
     self.assertEqual(parser.get_dict(), {'test': 10.0})
示例#26
0
 def test_add(self):
     # Only registered keys are reported; unknown input keys are ignored.
     parser = ClassArg().add('test', int)
     parser.parse({'test': 10, 'test1': 20})
     self.assertEqual(parser.get_dict(), {'test': 10})
示例#27
0
    def _init_param(self, jdata):
        """Build descriptor, fitting net, model, learning rate, loss and
        training options from the input dict ``jdata``.

        Args:
            jdata: parsed input script.
                Required: ``model`` (with ``descriptor`` and ``fitting_net``),
                ``learning_rate`` and ``training``.
                Optional: ``loss`` (may be missing or null).

        Raises:
            RuntimeError: for an unknown descriptor, fitting, learning-rate or
                loss type, or for an unsupported descriptor/fitting combination.
        """
        # model config
        model_param = j_must_have(jdata, 'model')
        descrpt_param = j_must_have(model_param, 'descriptor')
        fitting_param = j_must_have(model_param, 'fitting_net')

        # descriptor
        descrpt_type = j_must_have(descrpt_param, 'type')
        if descrpt_type == 'loc_frame':
            self.descrpt = DescrptLocFrame(descrpt_param)
        elif descrpt_type == 'se_a':
            self.descrpt = DescrptSeA(descrpt_param)
        elif descrpt_type == 'se_r':
            self.descrpt = DescrptSeR(descrpt_param)
        elif descrpt_type == 'se_ar':
            self.descrpt = DescrptSeAR(descrpt_param)
        else:
            raise RuntimeError('unknow model type ' + descrpt_type)

        # fitting net; a missing 'type' key means an energy fitting
        fitting_type = fitting_param.get('type', 'ener')
        if fitting_type == 'ener':
            self.fitting = EnerFitting(fitting_param, self.descrpt)
        elif fitting_type == 'wfc':
            self.fitting = WFCFitting(fitting_param, self.descrpt)
        elif fitting_type == 'dipole':
            if descrpt_type == 'se_a':
                self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting dipole only supports descrptors: se_a')
        elif fitting_type == 'polar':
            if descrpt_type == 'loc_frame':
                self.fitting = PolarFittingLocFrame(fitting_param,
                                                    self.descrpt)
            elif descrpt_type == 'se_a':
                self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
            else:
                raise RuntimeError(
                    'fitting polar only supports descrptors: loc_frame and se_a'
                )
        elif fitting_type == 'global_polar':
            if descrpt_type == 'se_a':
                self.fitting = GlobalPolarFittingSeA(fitting_param,
                                                     self.descrpt)
            else:
                # Only se_a is accepted above, so the message names only se_a.
                raise RuntimeError(
                    'fitting global_polar only supports descrptors: se_a'
                )
        else:
            raise RuntimeError('unknow fitting type ' + fitting_type)

        # init model
        # infer model type by fitting_type
        if fitting_type == Model.model_type:
            self.model = Model(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'wfc':
            self.model = WFCModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'dipole':
            self.model = DipoleModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'polar':
            self.model = PolarModel(model_param, self.descrpt, self.fitting)
        elif fitting_type == 'global_polar':
            self.model = GlobalPolarModel(model_param, self.descrpt,
                                          self.fitting)
        else:
            raise RuntimeError('get unknown fitting type when building model')

        # learning rate; 'exp' is the default and the only supported type
        lr_param = j_must_have(jdata, 'learning_rate')
        lr_type = lr_param.get('type', 'exp')
        if lr_type == 'exp':
            self.lr = LearningRateExp(lr_param)
        else:
            raise RuntimeError('unknown learning_rate type ' + lr_type)

        # loss
        # infer loss type by fitting_type; a missing or null 'loss' section
        # selects the 'std' loss and hands loss_param=None to the constructor
        loss_param = jdata.get('loss')
        loss_type = 'std' if loss_param is None else loss_param.get('type', 'std')

        if fitting_type == 'ener':
            if loss_type == 'std':
                self.loss = EnerStdLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            elif loss_type == 'ener_dipole':
                self.loss = EnerDipoleLoss(
                    loss_param, starter_learning_rate=self.lr.start_lr())
            else:
                raise RuntimeError('unknow loss type')
        elif fitting_type == 'wfc':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='wfc',
                                   tensor_size=self.model.get_out_size(),
                                   label_name='wfc')
        elif fitting_type == 'dipole':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='dipole',
                                   tensor_size=3,
                                   label_name='dipole')
        elif fitting_type == 'polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='polar',
                                   tensor_size=9,
                                   label_name='polarizability')
        elif fitting_type == 'global_polar':
            self.loss = TensorLoss(loss_param,
                                   model=self.model,
                                   tensor_name='global_polar',
                                   tensor_size=9,
                                   atomic=False,
                                   label_name='polarizability')
        else:
            raise RuntimeError(
                'get unknown fitting type when building loss function')

        # training
        training_param = j_must_have(jdata, 'training')

        tr_args = ClassArg()\
                  .add('numb_test',     int,    default = 1)\
                  .add('disp_file',     str,    default = 'lcurve.out')\
                  .add('disp_freq',     int,    default = 100)\
                  .add('save_freq',     int,    default = 1000)\
                  .add('save_ckpt',     str,    default = 'model.ckpt')\
                  .add('display_in_training', bool, default = True)\
                  .add('timing_in_training',  bool, default = True)\
                  .add('profiling',     bool,   default = False)\
                  .add('profiling_file',str,    default = 'timeline.json')\
                  .add('sys_probs',   list    )\
                  .add('auto_prob_style', str, default = "prob_sys_size")
        tr_data = tr_args.parse(training_param)
        self.numb_test = tr_data['numb_test']
        self.disp_file = tr_data['disp_file']
        self.disp_freq = tr_data['disp_freq']
        self.save_freq = tr_data['save_freq']
        self.save_ckpt = tr_data['save_ckpt']
        self.display_in_training = tr_data['display_in_training']
        self.timing_in_training = tr_data['timing_in_training']
        self.profiling = tr_data['profiling']
        self.profiling_file = tr_data['profiling_file']
        self.sys_probs = tr_data['sys_probs']
        self.auto_prob_style = tr_data['auto_prob_style']
        self.useBN = False
        # Frame parameters are only tracked for energy fittings.
        if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
            self.numb_fparam = self.fitting.get_numb_fparam()
        else:
            self.numb_fparam = 0