예제 #1
0
 def test_sub_variants(self):
     """An incrementally built Argument must equal the nested-constructor one.

     ``ca`` is assembled with ``add_subfield``/``add_subvariant``/
     ``add_choice`` and compared with ``ref``. Afterwards ``set_default`` on
     the variant must make ``ca`` equal to ``ref1`` (``optional=True``,
     ``default_tag="type1"``), and ``set_default(False)`` must restore
     equality with the plain ``ref``.
     """
     ref = Argument("base", dict, [
         Argument("sub1", int),
         Argument("sub2", str)
     ], [
         Variant("vnt_flag", [
             Argument("type1", dict, [
                 Argument("shared", int),
                 Argument("vnt1_1", int),
                 Argument("vnt1_2", dict, [
                     Argument("vnt1_1_1", int)
                 ])
             ]),
             Argument("type2", dict, [
                 Argument("shared", int),
                 Argument("vnt2_1", int),
             ])
         ])
     ])
     ca = Argument("base", dict)
     ca.add_subfield("sub1", int)
     ca.add_subfield("sub2", str)
     v1 = ca.add_subvariant("vnt_flag")
     vt1 = v1.add_choice("type1", dict)
     vt1.add_subfield("shared", int)
     vt1.add_subfield("vnt1_1", int)
     vt1s2 = vt1.add_subfield("vnt1_2", dict)
     vt1s2.add_subfield("vnt1_1_1", int)
     # No dtype given: presumably add_choice defaults to dict, since ``ref``
     # declares type2 as dict and the two trees must compare equal.
     vt2 = v1.add_choice("type2")
     vt2.add_subfield("shared", int)
     vt2.add_subfield("vnt2_1", int)
     # assertEqual (not assertTrue(a == b)) so a failure shows both trees.
     self.assertEqual(ca, ref)
     # make sure we can modify the reference
     ref1 = Argument("base", dict, [
         Argument("sub1", int),
         Argument("sub2", str)
     ], [
         Variant("vnt_flag", [
             Argument("type1", dict, [
                 Argument("shared", int),
                 Argument("vnt1_1", int),
                 Argument("vnt1_2", dict, [
                     Argument("vnt1_1_1", int)
                 ])
             ]),
             Argument("type2", dict, [
                 Argument("shared", int),
                 Argument("vnt2_1", int),
             ])
         ], optional=True, default_tag="type1")
     ])
     v1.set_default("type1")
     self.assertEqual(ca, ref1)
     # Passing False clears the default again, so ``ca`` matches ``ref``.
     v1.set_default(False)
     self.assertEqual(ca, ref)
예제 #2
0
 def test_idx_fields(self):
     """Path indexing (``arg[...]`` and ``arg.I[...]``) resolves sub-arguments.

     ``arg[path]`` is relative to the argument's children. ``arg.I[path]``
     starts with the argument's own name. ``[tag]`` selects a variant
     choice. Identity checks use ``assertIs``, so a failure reports the
     objects involved instead of "False is not true".
     """
     s1 = Argument("sub1", int)
     vt1 = Argument("type1", dict, [
         Argument("shared", str),
         Argument("vnt1_1", dict, [
             Argument("vnt1_1_1", int)
         ])
     ])
     vt2 = Argument("type2", dict, [
         Argument("shared", int),
     ])
     v1 = Variant("vnt_flag", [vt1, vt2])
     ca = Argument("base", dict, [s1], [v1])
     # The empty path and "." both denote the argument itself.
     self.assertIs(ca[''], ca)
     self.assertIs(ca['.'], ca)
     self.assertEqual(ca['sub1'], ca["./sub1"])
     self.assertEqual(ca["./sub1"], s1)
     with self.assertRaises(KeyError):
         ca["sub2"]
     # "[tag]" selects a variant choice; repeated slashes are tolerated.
     self.assertIs(ca['[type1]'], vt1)
     self.assertIs(ca['[type1]///'], vt1)
     self.assertEqual(ca['[type1]/vnt1_1/vnt1_1_1'], Argument("vnt1_1_1", int))
     self.assertEqual(ca['[type2]//shared'], Argument("shared", int))
     # Plain indexing does not match the argument's own name...
     with self.assertRaises(KeyError):
         s1["sub1"]
     # ...while ``.I`` expects the path to start with it.
     self.assertIs(s1.I["sub1"], s1)
     self.assertIs(ca.I["base[type1]"], vt1)
     self.assertEqual(ca.I['base[type2]//shared'], Argument("shared", int))
예제 #3
0
def loss_variant_type_args():
    """Build the ``type`` variant that selects the loss.

    The only choice is ``ener`` (fields from ``loss_ener()``). The key is
    optional and defaults to ``ener``.
    """
    doc_loss = (
        'The type of the loss. For fitting type `ener`, the loss type should be set to `ener` or left unset. '
        'For tensorial fitting types `dipole`, `polar` and `global_polar`, the type should be left unset.'
        # Escaped backslash: emits the same backslash-period as before
        # without an invalid escape sequence (a SyntaxWarning on 3.12+).
        # NOTE(review): presumably an RST escape; kept so rendered docs
        # stay unchanged.
        '\n\\.'
    )

    return Variant("type", [Argument("ener", dict, loss_ener())],
                   optional=True,
                   default_tag='ener',
                   doc=doc_loss)
예제 #4
0
def loss_variant_type_args():
    """Build the ``type`` variant that selects the loss.

    The only choice is ``ener`` (fields from ``loss_ener()``). The key is
    optional and defaults to ``ener``.
    """
    # The escaped backslash reproduces the original trailing backslash-period
    # without an invalid escape sequence (a SyntaxWarning on Python 3.12+).
    doc_loss = 'The type of the loss. \n\\.'

    return Variant("type", [Argument("ener", dict, loss_ener())],
                   optional=True,
                   default_tag='ener',
                   doc=doc_loss)
예제 #5
0
def learning_rate_variant_type_args():
    """Build the ``type`` variant for the learning-rate schedule.

    Only the ``exp`` schedule (fields from ``learning_rate_exp()``) is
    offered. The key may be omitted, in which case it defaults to ``exp``.
    """
    exp_choice = Argument("exp", dict, learning_rate_exp())
    return Variant(
        "type",
        [exp_choice],
        optional=True,
        default_tag='exp',
        doc='The type of the learning rate.',
    )
예제 #6
0
def descrpt_variant_type_args():
    """Build the ``type`` variant that selects the descriptor.

    Choices:
    - ``loc_frame``
    - ``se_a``
    - ``se_r``
    - ``se_a_3be`` (alias ``se_at``)
    - ``se_a_tpe`` (alias ``se_a_ebd``)
    - ``hybrid``

    No ``optional``/``default_tag`` is passed, so this variant declares no
    default tag.
    """
    # The former make_link(...) locals were never used and have been removed.
    doc_descrpt_type = (
        'The type of the descriptor. See explanation below. \n\n'
        '- `loc_frame`: Defines a local frame at each atom, and then compute the descriptor as local coordinates under this frame.\n\n'
        '- `se_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n'
        '- `se_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n'
        '- `se_a_3be`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n'
        '- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n'
        '- `hybrid`: Concatenate a list of descriptors into a new descriptor.'
    )

    return Variant("type", [
        Argument("loc_frame", dict, descrpt_local_frame_args()),
        Argument("se_a", dict, descrpt_se_a_args()),
        Argument("se_r", dict, descrpt_se_r_args()),
        Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias=['se_at']),
        Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias=['se_a_ebd']),
        Argument("hybrid", dict, descrpt_hybrid_args()),
    ], doc=doc_descrpt_type)
예제 #7
0
def descrpt_hybrid_args():
    """Return the argument list of the ``hybrid`` descriptor.

    The list holds a single ``list`` argument (``repeat=True``,
    ``optional=False``). Its items are typed by a ``type`` variant over:
    - ``loc_frame``
    - ``se_a``
    - ``se_r``
    - ``se_a_3be`` (alias ``se_at``)
    - ``se_a_tpe`` (alias ``se_a_ebd``)

    ``hybrid`` is not among the item choices.
    """
    # Plain literal: the former f-string had no placeholders.
    doc_list = 'A list of descriptor definitions'

    item_variant = Variant("type", [
        Argument("loc_frame", dict, descrpt_local_frame_args()),
        Argument("se_a", dict, descrpt_se_a_args()),
        Argument("se_r", dict, descrpt_se_r_args()),
        Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias=['se_at']),
        Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias=['se_a_ebd']),
    ])

    return [
        Argument(
            "list",
            list,
            [],
            [item_variant],
            repeat=True,
            optional=False,
            doc=doc_list,
            fold_subdoc=True,
        )
    ]
예제 #8
0
 def test_multi_variants(self):
     """Smoke test: docs and JSON for an argument with cascaded variants.

     ``base`` has one variant ``vnt_flag``. Its ``type3`` choice owns two
     further variants (``vnt3_flag1`` and ``vnt3_flag2``). The results of
     ``gen_doc`` and of ``json.dumps`` with ``ArgumentEncoder`` are not
     inspected; the test only requires that neither call raises.
     """
     top_fields = [Argument("sub1", int), Argument("sub2", str)]
     choice_one = Argument("type1", dict, [
         Argument("shared", int),
         Argument("vnt1_1", int),
         Argument("vnt1_2", dict, [Argument("vnt1_1_1", int)]),
     ])
     choice_two = Argument("type2", dict, [
         Argument("shared", int),
         Argument("vnt2_1", int),
     ])
     # testing cascade variants here: type3 carries two variants of its own
     type3_fields = [Argument("vnt3_1", int)]
     cascade_one = Variant("vnt3_flag1", [
         Argument("v3f1t1", dict, [
             Argument('v3f1t1_1', int),
             Argument('v3f1t1_2', int),
         ]),
         Argument("v3f1t2", dict, [Argument('v3f1t2_1', int)]),
     ])
     cascade_two = Variant("vnt3_flag2", [
         Argument("v3f2t1", dict, [
             Argument('v3f2t1_1', int),
             Argument('v3f2t1_2', int),
         ]),
         Argument("v3f2t2", dict, [Argument('v3f2t2_1', int)]),
     ])
     choice_three = Argument("type3", dict, type3_fields,
                             [cascade_one, cascade_two])
     top_variant = Variant("vnt_flag", [choice_one, choice_two, choice_three])
     ca = Argument("base", dict, top_fields, [top_variant])
     doc_text = ca.gen_doc()
     json_text = json.dumps(ca, cls=ArgumentEncoder)
예제 #9
0
def modifier_variant_type_args():
    """Build the ``type`` variant that selects the modifier.

    The only choice is ``dipole_charge`` (fields from
    ``modifier_dipole_charge()``). The variant is declared with
    ``optional=False`` and no ``default_tag``.
    """
    # "- " (dash plus space) is required for the item to render as a list
    # entry.
    doc_modifier_type = (
        "The type of modifier. See explanation below.\n\n"
        "- `dipole_charge`: Use WFCC to model the electronic structure of the system. "
        "Correct the long-range interaction."
    )
    return Variant("type",
                   [
                       Argument("dipole_charge", dict, modifier_dipole_charge()),
                   ],
                   optional=False,
                   doc=doc_modifier_type)
예제 #10
0
 def test_sub_variants(self):
     """Smoke test: anchored docs and JSON for two sibling variants.

     ``base`` has two optional variants, ``vnt_flag`` and ``vnt_flag1``,
     both defaulting to ``type1``. One choice and one variant carry a
     ``doc``. The outputs of ``gen_doc(make_anchor=True)`` and of
     ``json.dumps`` with ``ArgumentEncoder`` are not inspected; the test
     only requires that neither call raises.
     """
     plain_fields = [Argument("sub1", int), Argument("sub2", str)]
     first_variant = Variant(
         "vnt_flag",
         [
             Argument(
                 "type1",
                 dict,
                 [
                     Argument("shared", int),
                     Argument("vnt1_1", int),
                     Argument("vnt1_2", dict, [
                         Argument("vnt1_1_1", int),
                         Argument("vnt1_1_2", int),
                     ]),
                 ],
                 doc="type1 doc here!",
             ),
             Argument("type2", dict, [
                 Argument("shared", int),
                 Argument("vnt2_1", int),
             ]),
         ],
         optional=True,
         default_tag="type1",
     )
     second_variant = Variant(
         "vnt_flag1",
         [
             Argument("type1", dict, [
                 Argument("1shared", int),
                 Argument("1vnt1_1", int),
                 Argument("1vnt1_2", dict, [
                     Argument("vnt1_1_1", int),
                     Argument("vnt1_1_2", int),
                 ]),
             ]),
             Argument("type2", dict, [
                 Argument("1shared", int),
                 Argument("1vnt2_1", int),
             ]),
         ],
         optional=True,
         default_tag="type1",
         doc="another vnt",
     )
     ca = Argument("base", dict, plain_fields, [first_variant, second_variant])
     doc_text = ca.gen_doc(make_anchor=True)
     json_text = json.dumps(ca, cls=ArgumentEncoder)
예제 #11
0
def descrpt_variant_type_args():
    """Build the ``type`` variant that selects the descriptor.

    Choices are ``loc_frame``, ``se_a``, ``se_r`` and ``se_ar``. No
    ``optional``/``default_tag`` is passed, so this variant declares no
    default tag.
    """
    doc_descrpt_type = (
        'The type of the descriptor. Valid types are `loc_frame`, `se_a`, `se_r` and `se_ar`. \n\n'
        '- `loc_frame`: Defines a local frame at each atom, and then compute the descriptor as local coordinates under this frame.\n\n'
        '- `se_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n'
        '- `se_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n'
        '- `se_ar`: A hybrid of `se_a` and `se_r`. Typically `se_a` has a smaller cut-off while the `se_r` has a larger cut-off.'
    )

    return Variant("type", [
        Argument("loc_frame", dict, descrpt_local_frame_args()),
        Argument("se_a", dict, descrpt_se_a_args()),
        Argument("se_r", dict, descrpt_se_r_args()),
        Argument("se_ar", dict, descrpt_se_ar_args())
    ], doc=doc_descrpt_type)
예제 #12
0
def fitting_variant_type_args():
    """Build the ``type`` variant that selects the fitting network.

    Choices are ``ener``, ``dipole``, ``polar`` and ``global_polar``. The
    key is optional and defaults to ``ener``.
    """
    # Named for what it documents (was the copy-pasted ``doc_descrpt_type``).
    doc_fitting_type = (
        'The type of the fitting. Valid types are `ener`, `dipole`, `polar` and `global_polar`. \n\n'
        '- `ener`: Fit an energy model (potential energy surface).\n\n'
        '- `dipole`: Fit an atomic dipole model. Atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file has number of frames lines and 3 times of number of selected atoms columns.\n\n'
        '- `polar`: Fit an atomic polarizability model. Atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file has number of frames lines and 9 times of number of selected atoms columns.\n\n'
        '- `global_polar`: Fit a polarizability model. Polarizability labels should be provided by `polarizability.npy` in each data system. The file has number of frames lines and 9 columns.'
    )

    return Variant("type", [Argument("ener", dict, fitting_ener()),
                            Argument("dipole", dict, fitting_dipole()),
                            Argument("polar", dict, fitting_polar()),
                            Argument("global_polar", dict, fitting_global_polar())],
                   optional=True,
                   default_tag='ener',
                   doc=doc_fitting_type)
예제 #13
0
 def test_idx_variants(self):
     """Indexing a Variant by tag returns that exact choice object.

     An unknown tag raises KeyError. ``assertIs`` reports the mismatching
     objects on failure, unlike ``assertTrue(x is y)``.
     """
     vt1 = Argument("type1", dict, [
         Argument("shared", int),
         Argument("vnt1_1", int),
         Argument("vnt1_2", dict, [
             Argument("vnt1_1_1", int)
         ])
     ])
     vt2 = Argument("type2", dict, [
         Argument("shared", int),
         Argument("vnt2_1", int),
     ])
     vnt = Variant("vnt_flag", [vt1, vt2])
     self.assertIs(vnt["type1"], vt1)
     self.assertIs(vnt["type2"], vt2)
     with self.assertRaises(KeyError):
         vnt["type3"]
예제 #14
0
 def test_complicated(self):
     """Normalization fills defaults, resolves aliases and trims keys.

     ``normalize`` takes the whole ``{"base": ...}`` dict and
     ``normalize_value`` takes the inner value; both are checked against the
     same expectations. Keys matching ``trim_pattern`` are dropped. A
     pattern that also matches declared argument names raises ValueError.
     """
     ca = Argument("base", dict, [
         Argument("sub1", int, optional=True, default=1, alias=["sub1a"]),
         Argument(
             "sub2",
             list, [
                 Argument(
                     "ss1", int, optional=True, default=21, alias=["ss1a"])
             ],
             repeat=True,
             alias=["sub2a"])
     ], [
         Variant("vnt_flag", [
             Argument("type1", dict, [
                 Argument("shared",
                          int,
                          optional=True,
                          default=-1,
                          alias=["shareda"]),
                 Argument("vnt1",
                          int,
                          optional=True,
                          default=111,
                          alias=["vnt1a"]),
             ]),
             Argument("type2",
                      dict, [
                          Argument("shared",
                                   int,
                                   optional=True,
                                   default=-2,
                                   alias=["sharedb"]),
                          Argument("vnt2",
                                   int,
                                   optional=True,
                                   default=222,
                                   alias=["vnt2a"]),
                      ],
                      alias=['type3'])
         ],
                 optional=True,
                 default_tag="type1")
     ])
     # Only sub2 is given (two empty items). Every other value, including
     # the variant tag and each item's ss1, must come from the declared
     # defaults.
     beg1 = {"base": {"sub2": [{}, {}]}}
     ref1 = {
         'base': {
             'sub1': 1,
             'sub2': [{
                 'ss1': 21
             }, {
                 'ss1': 21
             }],
             'vnt_flag': "type1",
             'shared': -1,
             'vnt1': 111
         }
     }
     # NOTE(review): beg1 is reused after normalize(); this assumes
     # normalize() does not modify its input in place — confirm.
     self.assertDictEqual(ca.normalize(beg1), ref1)
     self.assertDictEqual(ca.normalize_value(beg1["base"]), ref1["base"])
     # Every key uses an alias ("type3" is an alias of the type2 choice).
     # The "_comment*" keys are expected to be removed by trim_pattern="_*".
     beg2 = {
         "base": {
             "sub1a": 2,
             "sub2a": [{
                 "ss1a": 22
             }, {
                 "_comment1": None
             }],
             "vnt_flag": "type3",
             "sharedb": -3,
             "vnt2a": 223,
             "_comment2": None
         }
     }
     # Aliases map to canonical names and the tag maps to "type2". The second
     # list item keeps only the ss1 default once its comment is trimmed.
     ref2 = {
         'base': {
             'sub1': 2,
             'sub2': [{
                 'ss1': 22
             }, {
                 'ss1': 21
             }],
             "vnt_flag": "type2",
             'shared': -3,
             'vnt2': 223
         }
     }
     self.assertDictEqual(ca.normalize(beg2, trim_pattern="_*"), ref2)
     self.assertDictEqual(
         ca.normalize_value(beg2["base"], trim_pattern="_*"), ref2["base"])
     # Patterns that would also match real argument names are rejected.
     with self.assertRaises(ValueError):
         ca.normalize(beg2, trim_pattern="sub*")
     with self.assertRaises(ValueError):
         ca.normalize(beg2, trim_pattern="vnt*")
예제 #15
0
 def test_complicated(self):
     """Incrementally built cascaded variants equal the reference and are indexable.

     ``ca`` is assembled with ``add_subfield``/``add_subvariant``/
     ``add_choice`` and must equal the nested-constructor ``ref``. The
     ``type3`` choice owns two variants, ``vnt3_flag1`` and ``vnt3_flag2``.
     Paths of the form ``[tag][flag=choice]`` must resolve to the exact
     objects the builders returned, and malformed paths must be rejected.
     """
     ref = Argument("base", dict, [
         Argument("sub1", int),
         Argument("sub2", str)
     ], [
         Variant("vnt_flag", [
             Argument("type1", dict, [
                 Argument("shared", int),
                 Argument("vnt1_1", int),
                 Argument("vnt1_2", dict, [
                     Argument("vnt1_1_1", int)
                 ])
             ]),
             Argument("type2", dict, [
                 Argument("shared", int),
                 Argument("vnt2_1", int),
             ]),
             Argument("type3", dict, [
                 Argument("vnt3_1", int)
             ], [  # testing cascade variants here
                 Variant("vnt3_flag1", [
                     Argument("v3f1t1", dict, [
                         Argument('v3f1t1_1', int),
                         Argument('v3f1t1_2', int)
                     ]),
                     Argument("v3f1t2", dict, [
                         Argument('v3f1t2_1', int)
                     ])
                 ]),
                 Variant("vnt3_flag2", [
                     Argument("v3f2t1", dict, [
                         Argument('v3f2t1_1', int),
                         Argument('v3f2t1_2', int)
                     ]),
                     Argument("v3f2t2", dict, [
                         Argument('v3f2t2_1', int)
                     ])
                 ])
             ])
         ])
     ])
     ca = Argument("base", dict)
     ca.add_subfield("sub1", int)
     ca.add_subfield("sub2", str)
     v1 = ca.add_subvariant("vnt_flag")
     vt1 = v1.add_choice("type1", dict)
     vt1.add_subfield("shared", int)
     vt1.add_subfield("vnt1_1", int)
     vt1s2 = vt1.add_subfield("vnt1_2", dict)
     vt1s2.add_subfield("vnt1_1_1", int)
     vt2 = v1.add_choice("type2")
     vt2.add_subfield("shared", int)
     vt2.add_subfield("vnt2_1", int)
     vt3 = v1.add_choice("type3")
     vt3.add_subfield("vnt3_1", int)
     vt3f1 = vt3.add_subvariant('vnt3_flag1')
     vt3f1t1 = vt3f1.add_choice("v3f1t1")
     vt3f1t1.add_subfield("v3f1t1_1", int)
     vt3f1t1s2 = vt3f1t1.add_subfield("v3f1t1_2", int)
     vt3f1t2 = vt3f1.add_choice("v3f1t2")
     vt3f1t2s1 = vt3f1t2.add_subfield("v3f1t2_1", int)
     vt3f2 = vt3.add_subvariant('vnt3_flag2')
     vt3f2t1 = vt3f2.add_choice("v3f2t1")
     vt3f2t1s1 = vt3f2t1.add_subfield("v3f2t1_1", int)
     vt3f2t1.add_subfield("v3f2t1_2", int)
     vt3f2t2 = vt3f2.add_choice("v3f2t2")
     vt3f2t2s1 = vt3f2t2.add_subfield("v3f2t2_1", int)
     self.assertEqual(ca, ref)
     # Below type3, the "[flag=choice]" form names which variant to follow.
     self.assertIs(ca['[type3][vnt3_flag1=v3f1t1]'], vt3f1t1)
     self.assertIs(ca.I['base[type3][vnt3_flag1=v3f1t1]/v3f1t1_2'], vt3f1t1s2)
     self.assertIs(ca.I['base[type3][vnt3_flag1=v3f1t2]/v3f1t2_1'], vt3f1t2s1)
     self.assertIs(ca.I['base[type3][vnt3_flag2=v3f2t1]/v3f2t1_1'], vt3f2t1s1)
     self.assertIs(ca.I['base[type3][vnt3_flag2=v3f2t2]/v3f2t2_1'], vt3f2t2s1)
     # A bare "[choice]" under type3 is rejected. Presumably this is because
     # type3 owns two variants, making the bare form ambiguous.
     with self.assertRaises((KeyError, ValueError)):
         ca.I['base[type3][v3f2t2]']
     # An unknown flag name is rejected.
     with self.assertRaises((KeyError, ValueError)):
         ca.I['base[type3][vnt3_flag3=v3f2t2]/v3f2t2_1']
예제 #16
0
 def test_sub_variants(self):
     """Validation with ``check`` and key flattening with ``flatten_sub``.

     Covers:
     - a valid dict per choice, including the ``type2a`` alias;
     - ``strict`` rejection of unknown keys;
     - wrong and unknown variant tags;
     - making a variant optional with a default tag;
     - cascaded variants under ``type3``;
     - construction-time rejection of duplicate tags and duplicate flag
       names.

     Later steps mutate both the test dicts and ``ca``'s variants, so
     statement order matters.
     """
     ca = Argument(
         "base",
         dict,
         [Argument("sub1", int),
          Argument("sub2", str)],
         [
             Variant(
                 "vnt_flag",
                 [
                     Argument("type1", dict, [
                         Argument("shared", int),
                         Argument("vnt1_1", int),
                         Argument("vnt1_2", dict,
                                  [Argument("vnt1_1_1", int)])
                     ]),
                     Argument("type2",
                              dict, [
                                  Argument("shared", int),
                                  Argument("vnt2_1", int),
                              ],
                              alias=['type2a']),
                     Argument(
                         "type3",
                         dict,
                         [Argument("vnt3_1", int)],
                         [  # testing cascade variants here
                             Variant("vnt3_flag1", [
                                 Argument("v3f1t1", dict, [
                                     Argument('v3f1t1_1', int),
                                     Argument('v3f1t1_2', int)
                                 ]),
                                 Argument("v3f1t2", dict,
                                          [Argument('v3f1t2_1', int)])
                             ]),
                             Variant("vnt3_flag2", [
                                 Argument("v3f2t1", dict, [
                                     Argument('v3f2t1_1', int),
                                     Argument('v3f2t1_2', int)
                                 ]),
                                 Argument("v3f2t2", dict,
                                          [Argument('v3f2t2_1', int)])
                             ])
                         ])
                 ])
         ])
     test_dict1 = {
         "base": {
             "sub1": 1,
             "sub2": "a",
             "vnt_flag": "type1",
             "shared": 10,
             "vnt1_1": 11,
             "vnt1_2": {
                 "vnt1_1_1": 111
             }
         }
     }
     # flatten_sub must yield exactly the keys present for the chosen branch.
     self.assertSetEqual(set(ca.flatten_sub(test_dict1["base"]).keys()),
                         set(test_dict1["base"].keys()))
     ca.check(test_dict1)
     test_dict2 = {
         "base": {
             "sub1": 1,
             "sub2": "a",
             "vnt_flag": "type2",
             "shared": 20,
             "vnt2_1": 21
         }
     }
     self.assertSetEqual(set(ca.flatten_sub(test_dict2["base"]).keys()),
                         set(test_dict2["base"].keys()))
     ca.check(test_dict2, strict=True)
     # The alias of type2 is accepted as the tag value.
     test_dict2["base"]["vnt_flag"] = "type2a"
     ca.check(test_dict2, strict=True)
     err_dict1 = {
         "base": {
             "sub1": 1,
             "sub2": "a",
             "vnt_flag": "type2",  # wrong tag: the remaining keys belong to type1
             "shared": 10,
             "vnt1_1": 11,
             "vnt1_2": {
                 "vnt1_1_1": 111
             }
         }
     }
     with self.assertRaises(ArgumentKeyError):
         ca.check(err_dict1)
     err_dict1["base"]["vnt_flag"] = "type1"
     ca.check(err_dict1, strict=True)  # correct tag and no extra keys: strict passes
     err_dict1["base"]["additional"] = "hahaha"
     ca.check(err_dict1)  # an unknown key is tolerated when not strict
     with self.assertRaises(ArgumentKeyError):
         ca.check(err_dict1, strict=True)  # ...but rejected when strict
     err_dict2 = {
         "base": {
             "sub1": 1,
             "sub2": "a",
             "vnt_flag": "badtype",  # tag is not among the choices
             "shared": 20,
             "vnt2_1": 21
         }
     }
     # An unknown tag is a value error, unlike the key errors above.
     with self.assertRaises(ArgumentValueError):
         ca.check(err_dict2)
     # test optional choice: a missing tag fails until the variant is made
     # optional with a default tag
     test_dict1["base"].pop("vnt_flag")
     with self.assertRaises(ArgumentKeyError):
         ca.check(test_dict1)
     ca.sub_variants["vnt_flag"].optional = True
     ca.sub_variants["vnt_flag"].default_tag = "type1"
     ca.check(test_dict1)
     # test cascade variants: type3 owns two nested variants, both tags given
     test_dict3 = {
         "base": {
             "sub1": 1,
             "sub2": "a",
             "vnt_flag": "type3",
             "vnt3_1": 31,
             "vnt3_flag1": "v3f1t1",
             "vnt3_flag2": "v3f2t2",
             "v3f1t1_1": 3111,
             "v3f1t1_2": 3112,
             "v3f2t2_1": 3221
         }
     }
     self.assertSetEqual(set(ca.flatten_sub(test_dict3["base"]).keys()),
                         set(test_dict3["base"].keys()))
     ca.check(test_dict3, strict=True)
     # Dropping a nested tag fails until that nested variant gets a default.
     test_dict3["base"].pop("vnt3_flag2")
     with self.assertRaises(ArgumentKeyError):
         ca.check(test_dict3)
     ca.sub_variants['vnt_flag'].choice_dict["type3"].sub_variants[
         "vnt3_flag2"].optional = True
     ca.sub_variants['vnt_flag'].choice_dict["type3"].sub_variants[
         "vnt3_flag2"].default_tag = 'v3f2t2'
     ca.check(test_dict3, strict=True)
     # make sure duplicate tag is not allowed
     with self.assertRaises(ValueError):
         Argument("base", dict, [], [
             Variant("flag",
                     [Argument("type1", dict),
                      Argument("type1", dict)])
         ])
     # ...nor two variants sharing one flag name on the same argument
     with self.assertRaises(ValueError):
         Argument("base", dict, [], [
             Variant("flag", [Argument("type1", dict)]),
             Variant("flag", [Argument("type2", dict)])
         ])