def test_sub_variants(self):
    """Incrementally built variants must equal a declaratively built reference.

    Builds ``base`` with two subfields and a ``vnt_flag`` variant through the
    ``add_subfield`` / ``add_subvariant`` / ``add_choice`` builder API and
    compares it with ``ref``. Then checks that ``set_default`` toggles the
    variant's ``optional``/``default_tag`` state and can be undone.
    """
    ref = Argument("base", dict, [
        Argument("sub1", int),
        Argument("sub2", str),
    ], [
        Variant("vnt_flag", [
            Argument("type1", dict, [
                Argument("shared", int),
                Argument("vnt1_1", int),
                Argument("vnt1_2", dict, [
                    Argument("vnt1_1_1", int),
                ]),
            ]),
            Argument("type2", dict, [
                Argument("shared", int),
                Argument("vnt2_1", int),
            ]),
        ]),
    ])
    ca = Argument("base", dict)
    ca.add_subfield("sub1", int)
    ca.add_subfield("sub2", str)
    v1 = ca.add_subvariant("vnt_flag")
    vt1 = v1.add_choice("type1", dict)
    vt1.add_subfield("shared", int)
    vt1.add_subfield("vnt1_1", int)
    vt1s2 = vt1.add_subfield("vnt1_2", dict)
    vt1s2.add_subfield("vnt1_1_1", int)
    # dtype omitted here; the equality check below relies on it matching
    # the ``dict`` used for "type2" in ``ref``
    vt2 = v1.add_choice("type2")
    vt2.add_subfield("shared", int)
    vt2.add_subfield("vnt2_1", int)
    self.assertEqual(ca, ref)

    # set_default(tag) should make the variant optional with that default tag
    ref1 = Argument("base", dict, [
        Argument("sub1", int),
        Argument("sub2", str),
    ], [
        Variant("vnt_flag", [
            Argument("type1", dict, [
                Argument("shared", int),
                Argument("vnt1_1", int),
                Argument("vnt1_2", dict, [
                    Argument("vnt1_1_1", int),
                ]),
            ]),
            Argument("type2", dict, [
                Argument("shared", int),
                Argument("vnt2_1", int),
            ]),
        ], optional=True, default_tag="type1"),
    ])
    v1.set_default("type1")
    self.assertEqual(ca, ref1)
    # set_default(False) should revert to the state described by ``ref``
    v1.set_default(False)
    self.assertEqual(ca, ref)
def test_idx_fields(self):
    """Path-style indexing via ``[]`` and ``.I[]`` resolves subfields and choices.

    ``[]`` paths are relative to the argument itself. ``.I[]`` paths start
    with the argument's own name.
    """
    s1 = Argument("sub1", int)
    vt1 = Argument("type1", dict, [
        Argument("shared", str),
        Argument("vnt1_1", dict, [
            Argument("vnt1_1_1", int),
        ]),
    ])
    vt2 = Argument("type2", dict, [
        Argument("shared", int),
    ])
    v1 = Variant("vnt_flag", [vt1, vt2])
    ca = Argument("base", dict, [s1], [v1])
    # the empty path and "." both refer to the argument itself
    self.assertIs(ca[''], ca)
    self.assertIs(ca['.'], ca)
    # a leading "./" is equivalent to the bare subfield name
    self.assertEqual(ca['sub1'], ca["./sub1"])
    self.assertEqual(ca["./sub1"], s1)
    with self.assertRaises(KeyError):
        ca["sub2"]
    # "[tag]" selects a variant choice; repeated/trailing slashes are tolerated
    self.assertIs(ca['[type1]'], vt1)
    self.assertIs(ca['[type1]///'], vt1)
    self.assertEqual(ca['[type1]/vnt1_1/vnt1_1_1'], Argument("vnt1_1_1", int))
    self.assertEqual(ca['[type2]//shared'], Argument("shared", int))
    # a leaf's own name is not a valid relative path for it...
    with self.assertRaises(KeyError):
        s1["sub1"]
    # ...but is the required first component of an .I[] path
    self.assertIs(s1.I["sub1"], s1)
    self.assertIs(ca.I["base[type1]"], vt1)
    self.assertEqual(ca.I['base[type2]//shared'], Argument("shared", int))
def loss_variant_type_args():
    """Return the ``type`` variant of the loss section.

    ``ener`` is the only choice offered. The tag may be omitted, in which case
    it defaults to ``ener``.
    """
    # The trailing "\\." spells out the literal backslash-period that the
    # original invalid escape "\." produced. Python warns on "\."
    # (DeprecationWarning, SyntaxWarning since 3.12); the value is unchanged.
    doc_loss = (
        'The type of the loss. For fitting type `ener`, the loss type should '
        'be set to `ener` or left unset. For tensorial fitting types '
        '`dipole`, `polar` and `global_polar`, the type should be left unset.'
        '\n\\.'
    )
    return Variant("type",
                   [Argument("ener", dict, loss_ener())],
                   optional=True,
                   default_tag='ener',
                   doc=doc_loss)
def loss_variant_type_args():
    """Return the ``type`` variant of the loss section.

    ``ener`` is the only choice offered. The tag may be omitted, in which case
    it defaults to ``ener``.
    """
    # "\\." spells out the literal backslash-period that the original invalid
    # escape "\." produced (DeprecationWarning, SyntaxWarning since 3.12);
    # the resulting string value is unchanged.
    doc_loss = 'The type of the loss. \n\\.'
    return Variant("type",
                   [Argument("ener", dict, loss_ener())],
                   optional=True,
                   default_tag='ener',
                   doc=doc_loss)
def learning_rate_variant_type_args():
    """Return the ``type`` variant of the learning-rate section.

    ``exp`` is currently the only choice. The tag may be omitted, in which
    case it defaults to ``exp``.
    """
    exp_choice = Argument("exp", dict, learning_rate_exp())
    return Variant(
        "type",
        [exp_choice],
        optional=True,
        default_tag='exp',
        doc='The type of the learning rate.',
    )
def descrpt_variant_type_args():
    """Return the ``type`` variant that selects the descriptor.

    The variant has no default tag, so a descriptor type must be given.
    ``se_a_3be`` also accepts the alias ``se_at``, and ``se_a_tpe`` accepts
    the alias ``se_a_ebd``.
    """
    doc_descrpt_type = (
        'The type of the descriptor. See explanation below. \n\n'
        '- `loc_frame`: Defines a local frame at each atom, and then computes '
        'the descriptor as local coordinates under this frame.\n\n'
        '- `se_a`: Used by the smooth edition of Deep Potential. The full '
        'relative coordinates are used to construct the descriptor.\n\n'
        '- `se_r`: Used by the smooth edition of Deep Potential. Only the '
        'distance between atoms is used to construct the descriptor.\n\n'
        '- `se_a_3be`: Used by the smooth edition of Deep Potential. The full '
        'relative coordinates are used to construct the descriptor. '
        'Three-body embedding will be used by this descriptor.\n\n'
        '- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full '
        'relative coordinates are used to construct the descriptor. '
        'Type embedding will be used by this descriptor.\n\n'
        '- `hybrid`: Concatenation of a list of descriptors as a new '
        'descriptor.'
    )
    return Variant("type", [
        Argument("loc_frame", dict, descrpt_local_frame_args()),
        Argument("se_a", dict, descrpt_se_a_args()),
        Argument("se_r", dict, descrpt_se_r_args()),
        Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias=['se_at']),
        Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias=['se_a_ebd']),
        Argument("hybrid", dict, descrpt_hybrid_args()),
    ], doc=doc_descrpt_type)
def descrpt_hybrid_args():
    """Return the arguments of the ``hybrid`` descriptor.

    The single required ``list`` argument repeats. Each entry selects a
    sub-descriptor through its own ``type`` variant. ``hybrid`` is not among
    those choices, so hybrids cannot be nested.
    """
    # plain string: there are no placeholders, so an f-string is unnecessary
    doc_list = 'A list of descriptor definitions'
    return [
        Argument("list", list, [], [
            Variant("type", [
                Argument("loc_frame", dict, descrpt_local_frame_args()),
                Argument("se_a", dict, descrpt_se_a_args()),
                Argument("se_r", dict, descrpt_se_r_args()),
                Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias=['se_at']),
                Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias=['se_a_ebd']),
            ]),
        ], repeat=True, optional=False, doc=doc_list, fold_subdoc=True),
    ]
def test_multi_variants(self):
    """Smoke test: doc generation and JSON encoding handle cascaded variants.

    Choice ``type3`` carries two sub-variants of its own. The test only checks
    that ``gen_doc`` and ``ArgumentEncoder`` run without raising; their output
    is not inspected.
    """
    base_fields = [Argument("sub1", int), Argument("sub2", str)]
    type1 = Argument("type1", dict, [
        Argument("shared", int),
        Argument("vnt1_1", int),
        Argument("vnt1_2", dict, [Argument("vnt1_1_1", int)]),
    ])
    type2 = Argument("type2", dict, [
        Argument("shared", int),
        Argument("vnt2_1", int),
    ])
    type3_fields = [Argument("vnt3_1", int)]
    # cascade: variants nested inside a choice of the outer variant
    type3_flag1 = Variant("vnt3_flag1", [
        Argument("v3f1t1", dict, [
            Argument('v3f1t1_1', int),
            Argument('v3f1t1_2', int),
        ]),
        Argument("v3f1t2", dict, [Argument('v3f1t2_1', int)]),
    ])
    type3_flag2 = Variant("vnt3_flag2", [
        Argument("v3f2t1", dict, [
            Argument('v3f2t1_1', int),
            Argument('v3f2t1_2', int),
        ]),
        Argument("v3f2t2", dict, [Argument('v3f2t2_1', int)]),
    ])
    type3 = Argument("type3", dict, type3_fields, [type3_flag1, type3_flag2])
    outer_flag = Variant("vnt_flag", [type1, type2, type3])
    ca = Argument("base", dict, base_fields, [outer_flag])

    docstr = ca.gen_doc()
    jsonstr = json.dumps(ca, cls=ArgumentEncoder)
def modifier_variant_type_args():
    """Return the ``type`` variant of the modifier section.

    The variant is not optional. ``dipole_charge`` is currently the only
    choice.
    """
    doc_modifier_type = (
        "The type of modifier. See explanation below.\n\n"
        # "- " (dash + space) is required for this to render as a list item,
        # matching the other variant docs
        "- `dipole_charge`: Use WFCC to model the electronic structure of the "
        "system. Correct the long-range interaction"
    )
    return Variant("type", [
        Argument("dipole_charge", dict, modifier_dipole_charge()),
    ], optional=False, doc=doc_modifier_type)
def test_sub_variants(self):
    """Smoke test: anchored doc generation and JSON encoding with two variants.

    ``base`` carries two optional variants, ``vnt_flag`` and ``vnt_flag1``,
    whose choices reuse the same tag names. The test only checks that
    ``gen_doc(make_anchor=True)`` and ``ArgumentEncoder`` run without raising.
    """
    fields = [Argument("sub1", int), Argument("sub2", str)]
    first_flag = Variant("vnt_flag", [
        Argument("type1", dict, [
            Argument("shared", int),
            Argument("vnt1_1", int),
            Argument("vnt1_2", dict, [
                Argument("vnt1_1_1", int),
                Argument("vnt1_1_2", int),
            ]),
        ], doc="type1 doc here!"),
        Argument("type2", dict, [
            Argument("shared", int),
            Argument("vnt2_1", int),
        ]),
    ], optional=True, default_tag="type1")
    second_flag = Variant("vnt_flag1", [
        Argument("type1", dict, [
            Argument("1shared", int),
            Argument("1vnt1_1", int),
            Argument("1vnt1_2", dict, [
                Argument("vnt1_1_1", int),
                Argument("vnt1_1_2", int),
            ]),
        ]),
        Argument("type2", dict, [
            Argument("1shared", int),
            Argument("1vnt2_1", int),
        ]),
    ], optional=True, default_tag="type1", doc="another vnt")
    ca = Argument("base", dict, fields, [first_flag, second_flag])

    docstr = ca.gen_doc(make_anchor=True)
    jsonstr = json.dumps(ca, cls=ArgumentEncoder)
def descrpt_variant_type_args():
    """Return the ``type`` variant that selects the descriptor.

    The variant has no default tag, so one of ``loc_frame``, ``se_a``,
    ``se_r`` or ``se_ar`` must be given.
    """
    doc_descrpt_type = (
        'The type of the descriptor. Valid types are `loc_frame`, `se_a`, '
        '`se_r` and `se_ar`. \n\n'
        '- `loc_frame`: Defines a local frame at each atom, and then computes '
        'the descriptor as local coordinates under this frame.\n\n'
        '- `se_a`: Used by the smooth edition of Deep Potential. The full '
        'relative coordinates are used to construct the descriptor.\n\n'
        '- `se_r`: Used by the smooth edition of Deep Potential. Only the '
        'distance between atoms is used to construct the descriptor.\n\n'
        '- `se_ar`: A hybrid of `se_a` and `se_r`. Typically `se_a` has a '
        'smaller cut-off while the `se_r` has a larger cut-off.'
    )
    return Variant("type", [
        Argument("loc_frame", dict, descrpt_local_frame_args()),
        Argument("se_a", dict, descrpt_se_a_args()),
        Argument("se_r", dict, descrpt_se_r_args()),
        Argument("se_ar", dict, descrpt_se_ar_args()),
    ], doc=doc_descrpt_type)
def fitting_variant_type_args():
    """Return the ``type`` variant of the fitting section.

    Choices are ``ener``, ``dipole``, ``polar`` and ``global_polar``. The tag
    may be omitted, in which case it defaults to ``ener``.
    """
    doc_fitting_type = (
        'The type of the fitting. Valid types are `ener`, `dipole`, `polar` '
        'and `global_polar`. \n\n'
        '- `ener`: Fit an energy model (potential energy surface).\n\n'
        '- `dipole`: Fit an atomic dipole model. Atomic dipole labels for all '
        'the selected atoms (see `sel_type`) should be provided by '
        '`dipole.npy` in each data system. The file has number of frames '
        'lines and 3 times of number of selected atoms columns.\n\n'
        '- `polar`: Fit an atomic polarizability model. Atomic polarizability '
        'labels for all the selected atoms (see `sel_type`) should be '
        'provided by `polarizability.npy` in each data system. The file has '
        'number of frames lines and 9 times of number of selected atoms '
        'columns.\n\n'
        '- `global_polar`: Fit a polarizability model. Polarizability labels '
        'should be provided by `polarizability.npy` in each data system. The '
        'file has number of frames lines and 9 columns.'
    )
    return Variant("type", [
        Argument("ener", dict, fitting_ener()),
        Argument("dipole", dict, fitting_dipole()),
        Argument("polar", dict, fitting_polar()),
        Argument("global_polar", dict, fitting_global_polar()),
    ], optional=True, default_tag='ener', doc=doc_fitting_type)
def test_idx_variants(self):
    """Indexing a Variant by tag returns that exact choice object.

    An unknown tag must raise ``KeyError``.
    """
    vt1 = Argument("type1", dict, [
        Argument("shared", int),
        Argument("vnt1_1", int),
        Argument("vnt1_2", dict, [
            Argument("vnt1_1_1", int),
        ]),
    ])
    vt2 = Argument("type2", dict, [
        Argument("shared", int),
        Argument("vnt2_1", int),
    ])
    vnt = Variant("vnt_flag", [vt1, vt2])
    self.assertIs(vnt["type1"], vt1)
    self.assertIs(vnt["type2"], vt2)
    with self.assertRaises(KeyError):
        vnt["type3"]
def test_complicated(self):
    """``normalize``/``normalize_value`` resolve aliases, fill defaults and trim keys.

    ``beg1`` relies only on defaults: the default variant tag and per-item
    defaults of the repeated ``sub2`` list. ``beg2`` uses aliases for fields,
    list items and a choice tag (``type3`` -> ``type2``), plus ``_comment*``
    keys that ``trim_pattern="_*"`` must drop. The trim patterns ``sub*`` and
    ``vnt*`` must raise ``ValueError``.
    """
    sub_fields = [
        Argument("sub1", int, optional=True, default=1, alias=["sub1a"]),
        Argument("sub2", list, [
            Argument("ss1", int, optional=True, default=21, alias=["ss1a"]),
        ], repeat=True, alias=["sub2a"]),
    ]
    choice1 = Argument("type1", dict, [
        Argument("shared", int, optional=True, default=-1, alias=["shareda"]),
        Argument("vnt1", int, optional=True, default=111, alias=["vnt1a"]),
    ])
    choice2 = Argument("type2", dict, [
        Argument("shared", int, optional=True, default=-2, alias=["sharedb"]),
        Argument("vnt2", int, optional=True, default=222, alias=["vnt2a"]),
    ], alias=['type3'])
    flag = Variant("vnt_flag", [choice1, choice2],
                   optional=True, default_tag="type1")
    ca = Argument("base", dict, sub_fields, [flag])

    # defaults only
    beg1 = {"base": {"sub2": [{}, {}]}}
    ref1 = {
        "base": {
            "sub1": 1,
            "sub2": [{"ss1": 21}, {"ss1": 21}],
            "vnt_flag": "type1",
            "shared": -1,
            "vnt1": 111,
        }
    }
    self.assertDictEqual(ca.normalize(beg1), ref1)
    self.assertDictEqual(ca.normalize_value(beg1["base"]), ref1["base"])

    # aliases everywhere, plus "_"-prefixed keys that the trim pattern removes
    beg2 = {
        "base": {
            "sub1a": 2,
            "sub2a": [{"ss1a": 22}, {"_comment1": None}],
            "vnt_flag": "type3",
            "sharedb": -3,
            "vnt2a": 223,
            "_comment2": None,
        }
    }
    ref2 = {
        "base": {
            "sub1": 2,
            "sub2": [{"ss1": 22}, {"ss1": 21}],
            "vnt_flag": "type2",
            "shared": -3,
            "vnt2": 223,
        }
    }
    self.assertDictEqual(ca.normalize(beg2, trim_pattern="_*"), ref2)
    self.assertDictEqual(
        ca.normalize_value(beg2["base"], trim_pattern="_*"), ref2["base"])

    # patterns that would match declared keys are rejected
    for bad_pattern in ("sub*", "vnt*"):
        with self.assertRaises(ValueError):
            ca.normalize(beg2, trim_pattern=bad_pattern)
def test_complicated(self):
    """The builder API reproduces a declarative tree with cascaded variants.

    Choice ``type3`` owns two sub-variants (``vnt3_flag1``, ``vnt3_flag2``).
    After checking equality with ``ref``, paths of the form
    ``[type3][flag=tag]/field`` must return the exact objects the builder
    created, and malformed or unknown paths must raise.
    """
    ref = Argument("base", dict, [
        Argument("sub1", int),
        Argument("sub2", str),
    ], [
        Variant("vnt_flag", [
            Argument("type1", dict, [
                Argument("shared", int),
                Argument("vnt1_1", int),
                Argument("vnt1_2", dict, [
                    Argument("vnt1_1_1", int),
                ]),
            ]),
            Argument("type2", dict, [
                Argument("shared", int),
                Argument("vnt2_1", int),
            ]),
            Argument("type3", dict, [
                Argument("vnt3_1", int),
            ], [  # testing cascade variants here
                Variant("vnt3_flag1", [
                    Argument("v3f1t1", dict, [
                        Argument('v3f1t1_1', int),
                        Argument('v3f1t1_2', int),
                    ]),
                    Argument("v3f1t2", dict, [
                        Argument('v3f1t2_1', int),
                    ]),
                ]),
                Variant("vnt3_flag2", [
                    Argument("v3f2t1", dict, [
                        Argument('v3f2t1_1', int),
                        Argument('v3f2t1_2', int),
                    ]),
                    Argument("v3f2t2", dict, [
                        Argument('v3f2t2_1', int),
                    ]),
                ]),
            ]),
        ]),
    ])

    ca = Argument("base", dict)
    ca.add_subfield("sub1", int)
    ca.add_subfield("sub2", str)
    v1 = ca.add_subvariant("vnt_flag")
    vt1 = v1.add_choice("type1", dict)
    vt1.add_subfield("shared", int)
    vt1.add_subfield("vnt1_1", int)
    vt1s2 = vt1.add_subfield("vnt1_2", dict)
    vt1s2.add_subfield("vnt1_1_1", int)
    vt2 = v1.add_choice("type2")
    vt2.add_subfield("shared", int)
    vt2.add_subfield("vnt2_1", int)
    vt3 = v1.add_choice("type3")
    vt3.add_subfield("vnt3_1", int)
    # cascade: two sub-variants attached to choice "type3"
    vt3f1 = vt3.add_subvariant('vnt3_flag1')
    vt3f1t1 = vt3f1.add_choice("v3f1t1")
    vt3f1t1.add_subfield("v3f1t1_1", int)
    vt3f1t1s2 = vt3f1t1.add_subfield("v3f1t1_2", int)
    vt3f1t2 = vt3f1.add_choice("v3f1t2")
    vt3f1t2s1 = vt3f1t2.add_subfield("v3f1t2_1", int)
    vt3f2 = vt3.add_subvariant('vnt3_flag2')
    vt3f2t1 = vt3f2.add_choice("v3f2t1")
    vt3f2t1s1 = vt3f2t1.add_subfield("v3f2t1_1", int)
    vt3f2t1.add_subfield("v3f2t1_2", int)
    vt3f2t2 = vt3f2.add_choice("v3f2t2")
    vt3f2t2s1 = vt3f2t2.add_subfield("v3f2t2_1", int)
    self.assertEqual(ca, ref)

    # "[flag=tag]" selects a choice of a named (nested) variant
    self.assertIs(ca['[type3][vnt3_flag1=v3f1t1]'], vt3f1t1)
    self.assertIs(ca.I['base[type3][vnt3_flag1=v3f1t1]/v3f1t1_2'], vt3f1t1s2)
    self.assertIs(ca.I['base[type3][vnt3_flag1=v3f1t2]/v3f1t2_1'], vt3f1t2s1)
    self.assertIs(ca.I['base[type3][vnt3_flag2=v3f2t1]/v3f2t1_1'], vt3f2t1s1)
    self.assertIs(ca.I['base[type3][vnt3_flag2=v3f2t2]/v3f2t2_1'], vt3f2t2s1)
    # a bare tag (no "flag=") under type3 is rejected
    with self.assertRaises((KeyError, ValueError)):
        ca.I['base[type3][v3f2t2]']
    # an unknown variant flag name is rejected
    with self.assertRaises((KeyError, ValueError)):
        ca.I['base[type3][vnt3_flag3=v3f2t2]/v3f2t2_1']
def test_sub_variants(self):
    """Exercise ``Argument.check`` and ``flatten_sub`` on a tree with variants.

    The tree has an aliased choice (``type2a`` for ``type2``) and a choice
    (``type3``) carrying two cascaded sub-variants. The steps run in a fixed
    order: later checks reuse dicts mutated by earlier ones, and the test
    mutates ``ca.sub_variants`` in place to make variants optional.
    """
    ca = Argument(
        "base", dict,
        [Argument("sub1", int), Argument("sub2", str)],
        [
            Variant(
                "vnt_flag", [
                    Argument("type1", dict, [
                        Argument("shared", int),
                        Argument("vnt1_1", int),
                        Argument("vnt1_2", dict, [Argument("vnt1_1_1", int)])
                    ]),
                    Argument("type2", dict, [
                        Argument("shared", int),
                        Argument("vnt2_1", int),
                    ], alias=['type2a']),
                    Argument(
                        "type3", dict,
                        [Argument("vnt3_1", int)],
                        [  # testing cascade variants here
                            Variant("vnt3_flag1", [
                                Argument("v3f1t1", dict, [
                                    Argument('v3f1t1_1', int),
                                    Argument('v3f1t1_2', int)
                                ]),
                                Argument("v3f1t2", dict, [Argument('v3f1t2_1', int)])
                            ]),
                            Variant("vnt3_flag2", [
                                Argument("v3f2t1", dict, [
                                    Argument('v3f2t1_1', int),
                                    Argument('v3f2t1_2', int)
                                ]),
                                Argument("v3f2t2", dict, [Argument('v3f2t2_1', int)])
                            ])
                        ])
                ])
        ])
    # valid "type1" input: flatten_sub must expose exactly the input's keys
    test_dict1 = {
        "base": {
            "sub1": 1,
            "sub2": "a",
            "vnt_flag": "type1",
            "shared": 10,
            "vnt1_1": 11,
            "vnt1_2": {
                "vnt1_1_1": 111
            }
        }
    }
    self.assertSetEqual(set(ca.flatten_sub(test_dict1["base"]).keys()),
                        set(test_dict1["base"].keys()))
    ca.check(test_dict1)
    # valid "type2" input, checked strictly; then the alias "type2a" as tag
    test_dict2 = {
        "base": {
            "sub1": 1,
            "sub2": "a",
            "vnt_flag": "type2",
            "shared": 20,
            "vnt2_1": 21
        }
    }
    self.assertSetEqual(set(ca.flatten_sub(test_dict2["base"]).keys()),
                        set(test_dict2["base"].keys()))
    ca.check(test_dict2, strict=True)
    test_dict2["base"]["vnt_flag"] = "type2a"
    ca.check(test_dict2, strict=True)
    # type1's keys under the tag "type2" -> ArgumentKeyError
    err_dict1 = {
        "base": {
            "sub1": 1,
            "sub2": "a",
            "vnt_flag": "type2",  # here is wrong
            "shared": 10,
            "vnt1_1": 11,
            "vnt1_2": {
                "vnt1_1_1": 111
            }
        }
    }
    with self.assertRaises(ArgumentKeyError):
        ca.check(err_dict1)
    err_dict1["base"]["vnt_flag"] = "type1"
    ca.check(err_dict1, strict=True)  # no additional should pass
    # an undeclared key is tolerated by default but rejected with strict=True
    err_dict1["base"]["additional"] = "hahaha"
    ca.check(err_dict1)  # without strict should pass
    with self.assertRaises(ArgumentKeyError):
        ca.check(err_dict1, strict=True)  # but should fail when strict
    # a tag that names no choice -> ArgumentValueError
    err_dict2 = {
        "base": {
            "sub1": 1,
            "sub2": "a",
            "vnt_flag": "badtype",  # here is wrong
            "shared": 20,
            "vnt2_1": 21
        }
    }
    with self.assertRaises(ArgumentValueError):
        ca.check(err_dict2)
    # test optional choice
    # a missing tag fails until the variant is made optional with a default
    # (mutates ca in place; later steps see the optional vnt_flag)
    test_dict1["base"].pop("vnt_flag")
    with self.assertRaises(ArgumentKeyError):
        ca.check(test_dict1)
    ca.sub_variants["vnt_flag"].optional = True
    ca.sub_variants["vnt_flag"].default_tag = "type1"
    ca.check(test_dict1)
    # test cascade variants
    # both sub-variants of type3 are selected by their own flag keys
    test_dict3 = {
        "base": {
            "sub1": 1,
            "sub2": "a",
            "vnt_flag": "type3",
            "vnt3_1": 31,
            "vnt3_flag1": "v3f1t1",
            "vnt3_flag2": "v3f2t2",
            "v3f1t1_1": 3111,
            "v3f1t1_2": 3112,
            "v3f2t2_1": 3221
        }
    }
    self.assertSetEqual(set(ca.flatten_sub(test_dict3["base"]).keys()),
                        set(test_dict3["base"].keys()))
    ca.check(test_dict3, strict=True)
    # same optional-with-default behaviour, one level down inside type3
    test_dict3["base"].pop("vnt3_flag2")
    with self.assertRaises(ArgumentKeyError):
        ca.check(test_dict3)
    ca.sub_variants['vnt_flag'].choice_dict["type3"].sub_variants[
        "vnt3_flag2"].optional = True
    ca.sub_variants['vnt_flag'].choice_dict["type3"].sub_variants[
        "vnt3_flag2"].default_tag = 'v3f2t2'
    ca.check(test_dict3, strict=True)
    # make sure duplicate tag is not allowed
    # (both a duplicated choice tag and a duplicated variant flag name)
    with self.assertRaises(ValueError):
        Argument("base", dict, [], [
            Variant("flag", [Argument("type1", dict), Argument("type1", dict)])
        ])
    with self.assertRaises(ValueError):
        Argument("base", dict, [], [
            Variant("flag", [Argument("type1", dict)]),
            Variant("flag", [Argument("type2", dict)])
        ])