コード例 #1
0
    def _setUp(self):
        """Build the dipole model and freeze it to ``dipole.pb``.

        Loads the job description from ``INPUT``, creates a ``DPTrainer`` and
        a ``DeepmdDataSystem`` over the training systems (each joined onto
        ``tests_path``), builds the model on a freshly reset default graph,
        and writes the graph, with variables converted to constants, to
        ``dipole.pb`` under ``modifier_datapath`` (joined onto ``tests_path``).
        """
        options = RunOptions(
            restart=None,
            init_model=None,
            log_path=None,
            log_level=30,
            mpi_log="master",
        )
        config = j_loader(INPUT)

        # The trainer is created first because the data system is
        # constructed with the model's rcut value.
        trainer = DPTrainer(config, run_opt=options)
        cutoff = trainer.model.get_rcut()

        # Read the training settings and set up the data system; every
        # system entry is joined onto tests_path.
        training = config['training']
        system_dirs = [
            tests_path / entry for entry in j_must_have(training, 'systems')
        ]
        prefix = j_must_have(training, 'set_prefix')
        n_batch = j_must_have(training, 'batch_size')
        n_test = j_must_have(training, 'numb_test')
        dataset = DeepmdDataSystem(
            system_dirs, n_batch, n_test, cutoff, set_prefix=prefix
        )
        dataset.add_dict(data_requirement)

        # Start from an empty default graph, then build the model into it.
        tf.reset_default_graph()
        trainer.build(dataset)

        # Names of the graph nodes that must survive freezing.
        output_nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type,model_attr/output_dim,model_attr/model_version"

        # Initialise the variables, convert them to constants and serialise
        # the resulting graph definition to disk.
        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            graph_def = tf.get_default_graph().as_graph_def()
            frozen_def = tf.graph_util.convert_variables_to_constants(
                session, graph_def, output_nodes.split(","))
            pb_path = str(
                tests_path / os.path.join(modifier_datapath, 'dipole.pb'))
            with tf.gfile.GFile(pb_path, "wb") as handle:
                handle.write(frozen_def.SerializeToString())
コード例 #2
0
    def _setUp(self):
        """Build the dipole model and freeze it to ``dipole.pb``.

        Reads the JSON job description named by ``Args().INPUT``, creates an
        ``NNPTrainer`` and a ``DeepmdDataSystem`` from its training section,
        builds the model on a freshly reset default graph, and writes the
        graph, with variables converted to constants, to ``dipole.pb`` inside
        ``modifier_datapath``.
        """
        cli_args = Args()
        options = RunOptions(cli_args, False)
        with open(cli_args.INPUT, 'r') as handle:
            config = json.load(handle)

        # The trainer is created first because the data system is
        # constructed with the model's rcut value.
        trainer = NNPTrainer(config, run_opt=options)
        cutoff = trainer.model.get_rcut()

        # Read the training settings and set up the data system.
        training = config['training']
        system_dirs = j_must_have(training, 'systems')
        prefix = j_must_have(training, 'set_prefix')
        n_batch = j_must_have(training, 'batch_size')
        n_test = j_must_have(training, 'numb_test')
        dataset = DeepmdDataSystem(
            system_dirs, n_batch, n_test, cutoff, set_prefix=prefix
        )
        dataset.add_dict(data_requirement)

        # Start from an empty default graph, then build the model into it.
        tf.reset_default_graph()
        trainer.build(dataset)

        # Names of the graph nodes that must survive freezing.
        output_nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type"

        # Initialise the variables, convert them to constants and serialise
        # the resulting graph definition to disk.
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            graph_def = tf.get_default_graph().as_graph_def()
            frozen_def = tf.graph_util.convert_variables_to_constants(
                session, graph_def, output_nodes.split(","))
            pb_path = os.path.join(modifier_datapath, 'dipole.pb')
            with tf.gfile.GFile(pb_path, "wb") as pb_file:
                pb_file.write(frozen_def.SerializeToString())
コード例 #3
0
 def tearDown(self):
     """Reset TensorFlow's default graph after the test."""
     tf.reset_default_graph()
コード例 #4
0
 def setUp(self):
     """Reset the default graph, then run the ``_setUp`` fixture."""
     # Clear the graph before delegating so _setUp starts from an empty one.
     tf.reset_default_graph()
     self._setUp()
コード例 #5
0
 def tearDown(self):
     """Reset the default graph and delete test artefacts.

     Removes the ``sys_test_0`` directory and the ``dipole.pb`` file from
     ``modifier_datapath`` when they exist.
     """
     tf.reset_default_graph()

     # Remove the sys_test_0 directory tree, if present.
     system_dir = os.path.join(modifier_datapath, 'sys_test_0')
     if os.path.isdir(system_dir):
         shutil.rmtree(system_dir)

     # Remove the dipole.pb file, if present.
     frozen_pb = os.path.join(modifier_datapath, 'dipole.pb')
     if os.path.isfile(frozen_pb):
         os.remove(frozen_pb)