def __init__(self, jdata, descrpt):
    if not isinstance(descrpt, DescrptSeA):
        raise RuntimeError('PolarFittingSeA only supports DescrptSeA')
    self.ntypes = descrpt.get_ntypes()
    self.dim_descrpt = descrpt.get_dim_out()
    args = ClassArg()\
           .add('neuron',      list,         default = [120,120,120], alias = 'n_neuron')\
           .add('resnet_dt',   bool,         default = True)\
           .add('fit_diag',    bool,         default = True)\
           .add('diag_shift',  [list,float], default = [0.0 for ii in range(self.ntypes)])\
           .add('scale',       [list,float], default = [1.0 for ii in range(self.ntypes)])\
           .add('sel_type',    [list,int],   default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
           .add('seed',        int)\
           .add("activation_function", str, default = "tanh")\
           .add('precision',           str, default = "default")
    class_data = args.parse(jdata)
    self.n_neuron = class_data['neuron']
    self.resnet_dt = class_data['resnet_dt']
    self.sel_type = class_data['sel_type']
    self.fit_diag = class_data['fit_diag']
    self.seed = class_data['seed']
    self.diag_shift = class_data['diag_shift']
    self.scale = class_data['scale']
    self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
    self.fitting_precision = get_precision(class_data['precision'])
    if type(self.sel_type) is not list:
        self.sel_type = [self.sel_type]
    if type(self.diag_shift) is not list:
        self.diag_shift = [self.diag_shift]
    if type(self.scale) is not list:
        self.scale = [self.scale]
    self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
    self.dim_rot_mat = self.dim_rot_mat_1 * 3
    self.useBN = False
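# Illustrative sketch only: a fitting_net section of the kind the constructor
# above parses through ClassArg, written here for a two-type system. The keys
# mirror the .add(...) calls; the concrete values are made-up placeholders,
# not recommended settings.
example_polar_fitting_jdata = {
    'neuron': [120, 120, 120],
    'resnet_dt': True,
    'fit_diag': True,
    'sel_type': [0],
    'scale': [1.0, 1.0],
    'diag_shift': [0.0, 0.0],
    'seed': 1,
}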
def __init__(self, jdata, descrpt):
    if not isinstance(descrpt, DescrptLocFrame):
        raise RuntimeError('WFC only supports DescrptLocFrame')
    self.ntypes = descrpt.get_ntypes()
    self.dim_descrpt = descrpt.get_dim_out()
    args = ClassArg()\
           .add('neuron',    list,       default = [120,120,120], alias = 'n_neuron')\
           .add('resnet_dt', bool,       default = True)\
           .add('wfc_numb',  int,        must = True)\
           .add('sel_type',  [list,int], default = [ii for ii in range(self.ntypes)], alias = 'wfc_type')\
           .add('seed',      int)\
           .add("activation_function", str, default = "tanh")\
           .add('precision',           str, default = "default")\
           .add('uniform_seed', bool, default = False)
    class_data = args.parse(jdata)
    self.n_neuron = class_data['neuron']
    self.resnet_dt = class_data['resnet_dt']
    self.wfc_numb = class_data['wfc_numb']
    self.sel_type = class_data['sel_type']
    self.seed = class_data['seed']
    self.uniform_seed = class_data['uniform_seed']
    self.seed_shift = one_layer_rand_seed_shift()
    self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
    self.fitting_precision = get_precision(class_data['precision'])
    self.useBN = False
def __init__(self, jdata, descrpt):
    # model param
    self.ntypes = descrpt.get_ntypes()
    self.dim_descrpt = descrpt.get_dim_out()
    args = ClassArg()\
           .add('numb_fparam', int,  default = 0)\
           .add('numb_aparam', int,  default = 0)\
           .add('neuron',      list, default = [120,120,120], alias = 'n_neuron')\
           .add('resnet_dt',   bool, default = True)\
           .add('rcond',       float, default = 1e-3)\
           .add('seed',        int)\
           .add('atom_ener',   list, default = [])\
           .add("activation_function", str, default = "tanh")\
           .add("precision",           str, default = "default")\
           .add("trainable",   [list, bool], default = True)
    class_data = args.parse(jdata)
    self.numb_fparam = class_data['numb_fparam']
    self.numb_aparam = class_data['numb_aparam']
    self.n_neuron = class_data['neuron']
    self.resnet_dt = class_data['resnet_dt']
    self.rcond = class_data['rcond']
    self.seed = class_data['seed']
    self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
    self.fitting_precision = get_precision(class_data['precision'])
    self.trainable = class_data['trainable']
    if type(self.trainable) is bool:
        self.trainable = [self.trainable] * (len(self.n_neuron) + 1)
    assert (len(self.trainable) == len(self.n_neuron) + 1), 'length of trainable should be that of n_neuron + 1'
    self.atom_ener = []
    for at, ae in enumerate(class_data['atom_ener']):
        if ae is not None:
            self.atom_ener.append(tf.constant(ae, global_tf_float_precision, name = "atom_%d_ener" % at))
        else:
            self.atom_ener.append(None)
    self.useBN = False
    self.bias_atom_e = None
    # data requirement
    if self.numb_fparam > 0:
        add_data_requirement('fparam', self.numb_fparam, atomic=False, must=True, high_prec=False)
        self.fparam_avg = None
        self.fparam_std = None
        self.fparam_inv_std = None
    if self.numb_aparam > 0:
        add_data_requirement('aparam', self.numb_aparam, atomic=True, must=True, high_prec=False)
        self.aparam_avg = None
        self.aparam_std = None
        self.aparam_inv_std = None
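# Illustrative sketch only: a minimal fitting_net section that the energy
# fitting constructor above would accept via ClassArg. Keys mirror the
# .add(...) calls; the values are made-up placeholders, not recommendations.
example_ener_fitting_jdata = {
    'neuron': [240, 240, 240],
    'resnet_dt': True,
    'numb_fparam': 0,
    'numb_aparam': 0,
    'trainable': True,
    'seed': 1,
    'precision': 'default',
}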
def test_add_multi(self):
    ca = ClassArg()\
         .add('test',  int)\
         .add('test1', str)
    test_dict = {'test': 10, 'test1': 'foo'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test1': 'foo', 'test': 10})
def __init__(self, jdata):
    args = ClassArg()\
           .add('sel',         list,  must = True)\
           .add('rcut',        float, default = 6.0)\
           .add('rcut_smth',   float, default = 5.5)\
           .add('neuron',      list,  default = [10, 20, 40])\
           .add('axis_neuron', int,   default = 4, alias = 'n_axis_neuron')\
           .add('resnet_dt',   bool,  default = False)\
           .add('trainable',   bool,  default = True)\
           .add('seed',        int)
    class_data = args.parse(jdata)
    self.sel_a = class_data['sel']
    self.rcut_r = class_data['rcut']
    self.rcut_r_smth = class_data['rcut_smth']
    self.filter_neuron = class_data['neuron']
    self.n_axis_neuron = class_data['axis_neuron']
    self.filter_resnet_dt = class_data['resnet_dt']
    self.seed = class_data['seed']
    self.trainable = class_data['trainable']
    # descrpt config
    self.sel_r = [0 for ii in range(len(self.sel_a))]
    self.ntypes = len(self.sel_a)
    assert (self.ntypes == len(self.sel_r))
    self.rcut_a = -1
    # numb of neighbors and numb of descrptors
    self.nnei_a = np.cumsum(self.sel_a)[-1]
    self.nnei_r = np.cumsum(self.sel_r)[-1]
    self.nnei = self.nnei_a + self.nnei_r
    self.ndescrpt_a = self.nnei_a * 4
    self.ndescrpt_r = self.nnei_r * 1
    self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
    self.useBN = False
def test_add_multi_types(self):
    ca = ClassArg()\
         .add('test',  [str, list])\
         .add('test1', [str, list])
    test_dict = {'test': [10, 20], 'test1': 10}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': [10, 20], 'test1': '10'})

def test_add_default_overwrite(self):
    ca = ClassArg().add('test', str, alias=['test1', 'test2'], default='bar')
    test_dict = {'test2': 'foo'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': 'foo'})

def test_multi_add(self):
    ca = ClassArg().add('test', int)
    test_dict = {'test2': 'foo'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': None})
    ca.add('test2', str)
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': None, 'test2': 'foo'})
def __init__(self, jdata):
    args = ClassArg()\
           .add('decay_steps', int,   must = False)\
           .add('decay_rate',  float, must = False)\
           .add('start_lr',    float, must = True)\
           .add('stop_lr',     float, must = False)
    self.cd = args.parse(jdata)
    self.start_lr_ = self.cd['start_lr']
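# Illustrative sketch only: the learning_rate keys parsed above. 'start_lr' is
# the only mandatory key (must = True); the placeholder values are not tuned.
example_lr_jdata = {
    'start_lr': 1e-3,
    'stop_lr': 1e-8,
    'decay_steps': 5000,
    'decay_rate': 0.95,
}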
def __init__(self, jdata, **kwarg):
    self.starter_learning_rate = kwarg['starter_learning_rate']
    args = ClassArg()\
           .add('start_pref_e',  float, default = 0.02)\
           .add('limit_pref_e',  float, default = 1.00)\
           .add('start_pref_f',  float, default = 1000)\
           .add('limit_pref_f',  float, default = 1.00)\
           .add('start_pref_v',  float, default = 0)\
           .add('limit_pref_v',  float, default = 0)\
           .add('start_pref_ae', float, default = 0)\
           .add('limit_pref_ae', float, default = 0)\
           .add('start_pref_pf', float, default = 0)\
           .add('limit_pref_pf', float, default = 0)\
           .add('relative_f',    float)
    class_data = args.parse(jdata)
    self.start_pref_e = class_data['start_pref_e']
    self.limit_pref_e = class_data['limit_pref_e']
    self.start_pref_f = class_data['start_pref_f']
    self.limit_pref_f = class_data['limit_pref_f']
    self.start_pref_v = class_data['start_pref_v']
    self.limit_pref_v = class_data['limit_pref_v']
    self.start_pref_ae = class_data['start_pref_ae']
    self.limit_pref_ae = class_data['limit_pref_ae']
    self.start_pref_pf = class_data['start_pref_pf']
    self.limit_pref_pf = class_data['limit_pref_pf']
    self.relative_f = class_data['relative_f']
    self.has_e = (self.start_pref_e != 0 or self.limit_pref_e != 0)
    self.has_f = (self.start_pref_f != 0 or self.limit_pref_f != 0)
    self.has_v = (self.start_pref_v != 0 or self.limit_pref_v != 0)
    self.has_ae = (self.start_pref_ae != 0 or self.limit_pref_ae != 0)
    self.has_pf = (self.start_pref_pf != 0 or self.limit_pref_pf != 0)
    # data required
    add_data_requirement('energy',    1, atomic=False, must=False, high_prec=True)
    add_data_requirement('force',     3, atomic=True,  must=False, high_prec=False)
    add_data_requirement('virial',    9, atomic=False, must=False, high_prec=False)
    add_data_requirement('atom_ener', 1, atomic=True,  must=False, high_prec=False)
    add_data_requirement('atom_pref', 1, atomic=True,  must=False, high_prec=False, repeat=3)
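# Illustrative sketch only: a loss section matching the preferences parsed
# above. The values are placeholders; the trainer would construct the loss
# roughly as EnerStdLoss(example_loss_jdata, starter_learning_rate = 1e-3).
example_loss_jdata = {
    'start_pref_e': 0.02,
    'limit_pref_e': 1.0,
    'start_pref_f': 1000,
    'limit_pref_f': 1.0,
    'start_pref_v': 0.0,
    'limit_pref_v': 0.0,
}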
def __init__(self, jdata):
    args = ClassArg()\
           .add('a', dict, must = True)\
           .add('r', dict, must = True)
    class_data = args.parse(jdata)
    self.param_a = class_data['a']
    self.param_r = class_data['r']
    self.descrpt_a = DescrptSeA(self.param_a)
    self.descrpt_r = DescrptSeR(self.param_r)
    assert (self.descrpt_a.get_ntypes() == self.descrpt_r.get_ntypes())
def __init__(self, jdata, descrpt, fitting):
    self.descrpt = descrpt
    self.rcut = self.descrpt.get_rcut()
    self.ntypes = self.descrpt.get_ntypes()
    # fitting
    self.fitting = fitting
    args = ClassArg()\
           .add('type_map', list, default = [])
    class_data = args.parse(jdata)
    self.type_map = class_data['type_map']
def __init__(self, jdata):
    args = ClassArg()\
           .add('sel_a',     list,  must = True)\
           .add('sel_r',     list,  must = True)\
           .add('rcut',      float, default = 6.0)\
           .add('axis_rule', list,  must = True)
    class_data = args.parse(jdata)
    self.sel_a = class_data['sel_a']
    self.sel_r = class_data['sel_r']
    self.axis_rule = class_data['axis_rule']
    self.rcut_r = class_data['rcut']
    # ntypes and rcut_a === -1
    self.ntypes = len(self.sel_a)
    assert(self.ntypes == len(self.sel_r))
    self.rcut_a = -1
    # numb of neighbors and numb of descrptors
    self.nnei_a = np.cumsum(self.sel_a)[-1]
    self.nnei_r = np.cumsum(self.sel_r)[-1]
    self.nnei = self.nnei_a + self.nnei_r
    self.ndescrpt_a = self.nnei_a * 4
    self.ndescrpt_r = self.nnei_r * 1
    self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
    self.davg = None
    self.dstd = None
    self.place_holders = {}
    avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype(global_np_float_precision)
    std_ones = np.ones ([self.ntypes, self.ndescrpt]).astype(global_np_float_precision)
    sub_graph = tf.Graph()
    with sub_graph.as_default():
        name_pfx = 'd_lf_'
        for ii in ['coord', 'box']:
            self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
        self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name = name_pfx+'t_type')
        self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name = name_pfx+'t_natoms')
        self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name = name_pfx+'t_mesh')
        self.stat_descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
            = op_module.descrpt(self.place_holders['coord'],
                                self.place_holders['type'],
                                self.place_holders['natoms_vec'],
                                self.place_holders['box'],
                                self.place_holders['default_mesh'],
                                tf.constant(avg_zero),
                                tf.constant(std_ones),
                                rcut_a = self.rcut_a,
                                rcut_r = self.rcut_r,
                                sel_a = self.sel_a,
                                sel_r = self.sel_r,
                                axis_rule = self.axis_rule)
    self.sub_sess = tf.Session(graph = sub_graph, config = default_tf_session_config)
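# Illustrative sketch only: the loc_frame descriptor section consumed above,
# written for a two-type system. 'sel_a', 'sel_r' and 'axis_rule' are the
# mandatory keys; the values below, including the axis_rule entries, are
# placeholders whose exact meaning is defined by the loc_frame descriptor
# documentation rather than by this example.
example_loc_frame_jdata = {
    'sel_a': [16, 32],
    'sel_r': [30, 60],
    'rcut': 6.0,
    'axis_rule': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
}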
def __init__(self, jdata, descrpt, fitting, var_name):
    self.model_type = var_name
    self.descrpt = descrpt
    self.rcut = self.descrpt.get_rcut()
    self.ntypes = self.descrpt.get_ntypes()
    # fitting
    self.fitting = fitting
    args = ClassArg()\
           .add('type_map',         list, default = [])\
           .add('data_stat_nbatch', int,  default = 10)
    class_data = args.parse(jdata)
    self.type_map = class_data['type_map']
    self.data_stat_nbatch = class_data['data_stat_nbatch']
def __init__(self, jdata, descrpt):
    if not isinstance(descrpt, DescrptLocFrame):
        raise RuntimeError('PolarFittingLocFrame only supports DescrptLocFrame')
    self.ntypes = descrpt.get_ntypes()
    self.dim_descrpt = descrpt.get_dim_out()
    args = ClassArg()\
           .add('neuron',    list,       default = [120,120,120], alias = 'n_neuron')\
           .add('resnet_dt', bool,       default = True)\
           .add('sel_type',  [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
           .add('seed',      int)
    class_data = args.parse(jdata)
    self.n_neuron = class_data['neuron']
    self.resnet_dt = class_data['resnet_dt']
    self.sel_type = class_data['sel_type']
    self.seed = class_data['seed']
    self.useBN = False
def __init__(self, jdata, **kwarg):
    self.starter_learning_rate = kwarg['starter_learning_rate']
    args = ClassArg()\
           .add('start_pref_e',  float, must = True, default = 0.1)\
           .add('limit_pref_e',  float, must = True, default = 1.00)\
           .add('start_pref_ed', float, must = True, default = 1.00)\
           .add('limit_pref_ed', float, must = True, default = 1.00)
    class_data = args.parse(jdata)
    self.start_pref_e = class_data['start_pref_e']
    self.limit_pref_e = class_data['limit_pref_e']
    self.start_pref_ed = class_data['start_pref_ed']
    self.limit_pref_ed = class_data['limit_pref_ed']
    # data required
    add_data_requirement('energy',        1, atomic=False, must=True, high_prec=True)
    add_data_requirement('energy_dipole', 3, atomic=False, must=True, high_prec=False)
def __init__(self, jdata, descrpt, fitting):
    self.descrpt = descrpt
    self.rcut = self.descrpt.get_rcut()
    self.ntypes = self.descrpt.get_ntypes()
    # fitting
    self.fitting = fitting
    self.numb_fparam = self.fitting.get_numb_fparam()
    args = ClassArg()\
           .add('type_map',  list,  default = [])\
           .add('rcond',     float, default = 1e-3)\
           .add('use_srtab', str)
    class_data = args.parse(jdata)
    self.type_map = class_data['type_map']
    self.srtab_name = class_data['use_srtab']
    self.rcond = class_data['rcond']
    if self.srtab_name is not None:
        self.srtab = TabInter(self.srtab_name)
        args.add('smin_alpha', float, must = True)\
            .add('sw_rmin',    float, must = True)\
            .add('sw_rmax',    float, must = True)
        class_data = args.parse(jdata)
        self.smin_alpha = class_data['smin_alpha']
        self.sw_rmin = class_data['sw_rmin']
        self.sw_rmax = class_data['sw_rmax']
    else:
        self.srtab = None
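# Illustrative sketch only: when 'use_srtab' names a tabulated short-range
# file, the three switching keys become mandatory (must = True above); without
# 'use_srtab' they may be omitted. The file name and values are placeholders.
example_model_jdata = {
    'type_map': ['O', 'H'],
    'use_srtab': 'srtab.txt',
    'smin_alpha': 0.1,
    'sw_rmin': 0.8,
    'sw_rmax': 1.0,
}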
def __init__(self, jdata, descrpt):
    # model param
    self.ntypes = descrpt.get_ntypes()
    self.dim_descrpt = descrpt.get_dim_out()
    args = ClassArg()\
           .add('numb_fparam', int,  default = 0)\
           .add('neuron',      list, default = [120,120,120], alias = 'n_neuron')\
           .add('resnet_dt',   bool, default = True)\
           .add('seed',        int)
    class_data = args.parse(jdata)
    self.numb_fparam = class_data['numb_fparam']
    self.n_neuron = class_data['neuron']
    self.resnet_dt = class_data['resnet_dt']
    self.seed = class_data['seed']
    self.useBN = False
    # data requirement
    if self.numb_fparam > 0:
        add_data_requirement('fparam', self.numb_fparam, atomic=False, must=False, high_prec=False)
def __init__(self, jdata):
    args = ClassArg()\
           .add('sel_a',     list,  must = True)\
           .add('sel_r',     list,  must = True)\
           .add('rcut',      float, default = 6.0)\
           .add('axis_rule', list,  must = True)
    class_data = args.parse(jdata)
    self.sel_a = class_data['sel_a']
    self.sel_r = class_data['sel_r']
    self.axis_rule = class_data['axis_rule']
    self.rcut_r = class_data['rcut']
    # ntypes and rcut_a === -1
    self.ntypes = len(self.sel_a)
    assert(self.ntypes == len(self.sel_r))
    self.rcut_a = -1
    # numb of neighbors and numb of descrptors
    self.nnei_a = np.cumsum(self.sel_a)[-1]
    self.nnei_r = np.cumsum(self.sel_r)[-1]
    self.nnei = self.nnei_a + self.nnei_r
    self.ndescrpt_a = self.nnei_a * 4
    self.ndescrpt_r = self.nnei_r * 1
    self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
def __init__(self, jdata):
    args = ClassArg()\
           .add('sel',           list,  must = True)\
           .add('rcut',          float, default = 6.0)\
           .add('rcut_smth',     float, default = 5.5)\
           .add('neuron',        list,  default = [10, 20, 40])\
           .add('axis_neuron',   int,   default = 4, alias = 'n_axis_neuron')\
           .add('resnet_dt',     bool,  default = False)\
           .add('trainable',     bool,  default = True)\
           .add('seed',          int)\
           .add('exclude_types', list,  default = [])\
           .add('set_davg_zero', bool,  default = False)\
           .add('activation_function', str, default = 'tanh')\
           .add('precision',           str, default = "default")
    class_data = args.parse(jdata)
    self.sel_a = class_data['sel']
    self.rcut_r = class_data['rcut']
    self.rcut_r_smth = class_data['rcut_smth']
    self.filter_neuron = class_data['neuron']
    self.n_axis_neuron = class_data['axis_neuron']
    self.filter_resnet_dt = class_data['resnet_dt']
    self.seed = class_data['seed']
    self.trainable = class_data['trainable']
    self.filter_activation_fn = get_activation_func(class_data['activation_function'])
    self.filter_precision = get_precision(class_data['precision'])
    exclude_types = class_data['exclude_types']
    self.exclude_types = set()
    for tt in exclude_types:
        assert (len(tt) == 2)
        self.exclude_types.add((tt[0], tt[1]))
        self.exclude_types.add((tt[1], tt[0]))
    self.set_davg_zero = class_data['set_davg_zero']
    # descrpt config
    self.sel_r = [0 for ii in range(len(self.sel_a))]
    self.ntypes = len(self.sel_a)
    assert (self.ntypes == len(self.sel_r))
    self.rcut_a = -1
    # numb of neighbors and numb of descrptors
    self.nnei_a = np.cumsum(self.sel_a)[-1]
    self.nnei_r = np.cumsum(self.sel_r)[-1]
    self.nnei = self.nnei_a + self.nnei_r
    self.ndescrpt_a = self.nnei_a * 4
    self.ndescrpt_r = self.nnei_r * 1
    self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r
    self.useBN = False
    self.dstd = None
    self.davg = None
    self.place_holders = {}
    avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype(global_np_float_precision)
    std_ones = np.ones([self.ntypes, self.ndescrpt]).astype(global_np_float_precision)
    sub_graph = tf.Graph()
    with sub_graph.as_default():
        name_pfx = 'd_sea_'
        for ii in ['coord', 'box']:
            self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx + 't_' + ii)
        self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name = name_pfx + 't_type')
        self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes + 2], name = name_pfx + 't_natoms')
        self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name = name_pfx + 't_mesh')
        self.stat_descrpt, descrpt_deriv, rij, nlist \
            = op_module.descrpt_se_a(self.place_holders['coord'],
                                     self.place_holders['type'],
                                     self.place_holders['natoms_vec'],
                                     self.place_holders['box'],
                                     self.place_holders['default_mesh'],
                                     tf.constant(avg_zero),
                                     tf.constant(std_ones),
                                     rcut_a = self.rcut_a,
                                     rcut_r = self.rcut_r,
                                     rcut_r_smth = self.rcut_r_smth,
                                     sel_a = self.sel_a,
                                     sel_r = self.sel_r)
    self.sub_sess = tf.Session(graph = sub_graph, config = default_tf_session_config)
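# Illustrative sketch only: a descriptor section of type se_a with the keys
# parsed above. 'sel' is the only mandatory key; 'exclude_types' lists pairs
# of atom types handled by the loop above. All values are placeholders.
example_se_a_jdata = {
    'sel': [46, 92],
    'rcut': 6.0,
    'rcut_smth': 5.8,
    'neuron': [25, 50, 100],
    'axis_neuron': 16,
    'resnet_dt': False,
    'exclude_types': [[0, 0]],
    'seed': 1,
}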
def test_add_none(self):
    ca = ClassArg().add('test', int)
    test_dict = {'test2': 'foo'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': None})

def test_add_must(self):
    ca = ClassArg().add('test', str, must=True)
    test_dict = {'test2': 'foo'}
    with self.assertRaises(RuntimeError):
        ca.parse(test_dict)

def test_add_alias(self):
    ca = ClassArg().add('test', str, alias=['test1', 'test2'])
    test_dict = {'test2': 'foo'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': 'foo'})

def test_add_wrong_type_cvt(self):
    ca = ClassArg().add('test', list)
    test_dict = {'test': 10}
    with self.assertRaises(TypeError):
        ca.parse(test_dict)

def test_add_type_cvt(self):
    ca = ClassArg().add('test', float)
    test_dict = {'test': '10'}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': 10.0})

def test_add(self):
    ca = ClassArg().add('test', int)
    test_dict = {'test': 10, 'test1': 20}
    ca.parse(test_dict)
    self.assertEqual(ca.get_dict(), {'test': 10})
def _init_param(self, jdata):
    # model config
    model_param = j_must_have(jdata, 'model')
    descrpt_param = j_must_have(model_param, 'descriptor')
    fitting_param = j_must_have(model_param, 'fitting_net')

    # descriptor
    descrpt_type = j_must_have(descrpt_param, 'type')
    if descrpt_type == 'loc_frame':
        self.descrpt = DescrptLocFrame(descrpt_param)
    elif descrpt_type == 'se_a':
        self.descrpt = DescrptSeA(descrpt_param)
    elif descrpt_type == 'se_r':
        self.descrpt = DescrptSeR(descrpt_param)
    elif descrpt_type == 'se_ar':
        self.descrpt = DescrptSeAR(descrpt_param)
    else:
        raise RuntimeError('unknown descriptor type ' + descrpt_type)

    # fitting net
    try:
        fitting_type = fitting_param['type']
    except:
        fitting_type = 'ener'
    if fitting_type == 'ener':
        self.fitting = EnerFitting(fitting_param, self.descrpt)
    elif fitting_type == 'wfc':
        self.fitting = WFCFitting(fitting_param, self.descrpt)
    elif fitting_type == 'dipole':
        if descrpt_type == 'se_a':
            self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
        else:
            raise RuntimeError('fitting dipole only supports descriptors: se_a')
    elif fitting_type == 'polar':
        if descrpt_type == 'loc_frame':
            self.fitting = PolarFittingLocFrame(fitting_param, self.descrpt)
        elif descrpt_type == 'se_a':
            self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
        else:
            raise RuntimeError('fitting polar only supports descriptors: loc_frame and se_a')
    elif fitting_type == 'global_polar':
        if descrpt_type == 'se_a':
            self.fitting = GlobalPolarFittingSeA(fitting_param, self.descrpt)
        else:
            raise RuntimeError('fitting global_polar only supports descriptors: se_a')
    else:
        raise RuntimeError('unknown fitting type ' + fitting_type)

    # init model
    # infer model type by fitting_type
    if fitting_type == Model.model_type:
        self.model = Model(model_param, self.descrpt, self.fitting)
    elif fitting_type == 'wfc':
        self.model = WFCModel(model_param, self.descrpt, self.fitting)
    elif fitting_type == 'dipole':
        self.model = DipoleModel(model_param, self.descrpt, self.fitting)
    elif fitting_type == 'polar':
        self.model = PolarModel(model_param, self.descrpt, self.fitting)
    elif fitting_type == 'global_polar':
        self.model = GlobalPolarModel(model_param, self.descrpt, self.fitting)
    else:
        raise RuntimeError('get unknown fitting type when building model')

    # learning rate
    lr_param = j_must_have(jdata, 'learning_rate')
    try:
        lr_type = lr_param['type']
    except:
        lr_type = 'exp'
    if lr_type == 'exp':
        self.lr = LearningRateExp(lr_param)
    else:
        raise RuntimeError('unknown learning_rate type ' + lr_type)

    # loss
    # infer loss type by fitting_type
    try:
        loss_param = jdata['loss']
        loss_type = loss_param.get('type', 'std')
    except:
        loss_param = None
        loss_type = 'std'
    if fitting_type == 'ener':
        if loss_type == 'std':
            self.loss = EnerStdLoss(loss_param, starter_learning_rate=self.lr.start_lr())
        elif loss_type == 'ener_dipole':
            self.loss = EnerDipoleLoss(loss_param, starter_learning_rate=self.lr.start_lr())
        else:
            raise RuntimeError('unknown loss type')
    elif fitting_type == 'wfc':
        self.loss = TensorLoss(loss_param,
                               model=self.model,
                               tensor_name='wfc',
                               tensor_size=self.model.get_out_size(),
                               label_name='wfc')
    elif fitting_type == 'dipole':
        self.loss = TensorLoss(loss_param,
                               model=self.model,
                               tensor_name='dipole',
                               tensor_size=3,
                               label_name='dipole')
    elif fitting_type == 'polar':
        self.loss = TensorLoss(loss_param,
                               model=self.model,
                               tensor_name='polar',
                               tensor_size=9,
                               label_name='polarizability')
    elif fitting_type == 'global_polar':
        self.loss = TensorLoss(loss_param,
                               model=self.model,
                               tensor_name='global_polar',
                               tensor_size=9,
                               atomic=False,
                               label_name='polarizability')
    else:
        raise RuntimeError('get unknown fitting type when building loss function')

    # training
    training_param = j_must_have(jdata, 'training')
    tr_args = ClassArg()\
              .add('numb_test',           int,  default = 1)\
              .add('disp_file',           str,  default = 'lcurve.out')\
              .add('disp_freq',           int,  default = 100)\
              .add('save_freq',           int,  default = 1000)\
              .add('save_ckpt',           str,  default = 'model.ckpt')\
              .add('display_in_training', bool, default = True)\
              .add('timing_in_training',  bool, default = True)\
              .add('profiling',           bool, default = False)\
              .add('profiling_file',      str,  default = 'timeline.json')\
              .add('sys_probs',           list)\
              .add('auto_prob_style',     str,  default = "prob_sys_size")
    tr_data = tr_args.parse(training_param)
    self.numb_test = tr_data['numb_test']
    self.disp_file = tr_data['disp_file']
    self.disp_freq = tr_data['disp_freq']
    self.save_freq = tr_data['save_freq']
    self.save_ckpt = tr_data['save_ckpt']
    self.display_in_training = tr_data['display_in_training']
    self.timing_in_training = tr_data['timing_in_training']
    self.profiling = tr_data['profiling']
    self.profiling_file = tr_data['profiling_file']
    self.sys_probs = tr_data['sys_probs']
    self.auto_prob_style = tr_data['auto_prob_style']
    self.useBN = False
    if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0:
        self.numb_fparam = self.fitting.get_numb_fparam()
    else:
        self.numb_fparam = 0
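# Illustrative sketch only: a training section covering the keys parsed by
# tr_args above; values are placeholders rather than recommendations.
example_training_jdata = {
    'numb_test': 1,
    'disp_file': 'lcurve.out',
    'disp_freq': 100,
    'save_freq': 1000,
    'save_ckpt': 'model.ckpt',
    'profiling': False,
}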