def _load_graph(
    frozen_graph_filename: "Path",
    prefix: str = "load",
    default_tf_graph: bool = False,
):
    """Load a frozen TF graph from a protobuf file.

    Parameters
    ----------
    frozen_graph_filename : Path
        path of the frozen ``.pb`` graph file on disk
    prefix : str
        name scope under which the imported nodes are placed
    default_tf_graph : bool
        if True, import the nodes into the process-wide default graph;
        otherwise import them into a fresh, isolated graph

    Returns
    -------
    tf.Graph
        the graph that now contains the imported nodes
    """
    # Read the protobuf file from disk and parse it to recover the
    # serialized graph_def.
    with tf.gfile.GFile(str(frozen_graph_filename), "rb") as handle:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(handle.read())

    if default_tf_graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name=prefix,
            producer_op_list=None,
        )
        return tf.get_default_graph()

    # Import into a brand-new graph so the caller gets an isolated copy
    # instead of polluting the default graph.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name=prefix,
            producer_op_list=None,
        )
    return graph
def _load_graph(self, frozen_graph_filename, prefix='load', default_tf_graph=False):
    """Load a frozen graph from *frozen_graph_filename* and return it.

    When *default_tf_graph* is True the nodes are imported into the current
    default graph under the name scope *prefix*; otherwise a new, isolated
    graph is created for them.
    """
    # Read the protobuf file from disk and parse the serialized graph_def.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as fp:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fp.read())

    if default_tf_graph:
        tf.import_graph_def(graph_def, input_map=None, return_elements=None,
                            name=prefix, producer_op_list=None)
        graph = tf.get_default_graph()
    else:
        # Import into a dedicated graph so the process-wide default
        # graph is left untouched.
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, input_map=None, return_elements=None,
                                name=prefix, producer_op_list=None)
    return graph
def freeze_graph(model_folder, output, output_node_names=None):
    """Freeze the checkpoint found in *model_folder* into one protobuf file.

    Parameters
    ----------
    model_folder : str
        directory holding the TF checkpoint files
    output : str
        file name of the frozen graph; written next to the checkpoint
    output_node_names : str, optional
        comma-separated names of the nodes to keep; when None they are
        inferred from the model attributes stored in the graph

    Raises
    ------
    RuntimeError
        if no checkpoint is found in *model_folder*
    """
    # local import so this block is self-contained even if the module
    # header changes; os is stdlib and cheap to import
    import os

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None:
        # fail with a clear message instead of an opaque AttributeError
        raise RuntimeError("no checkpoint found in folder %s" % model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Place the frozen graph next to the checkpoint files.  os.path handles
    # OS-specific separators, unlike the previous manual "/".join on a
    # '/'-split path which broke on Windows-style paths.
    output_graph = os.path.join(os.path.dirname(input_checkpoint), output)

    # Before exporting our graph, we need to decide which output nodes to
    # keep -- this is how TF decides what part of the graph it has to keep
    # and what part it can dump.
    # NOTE: this variable is plural, because you can have multiple output
    # nodes, e.g. "energy_test,force_test,virial_test,t_rcut"

    # We clear devices to allow TensorFlow to control on which device it
    # will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    nodes = [n.name for n in input_graph_def.node]

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        # model/modifier type strings are stored as byte tensors in the graph
        model_type = sess.run(
            'model_attr/model_type:0', feed_dict={}).decode('utf-8')
        if 'modifier_attr/type' in nodes:
            modifier_type = sess.run(
                'modifier_attr/type:0', feed_dict={}).decode('utf-8')
        else:
            modifier_type = None
        if output_node_names is None:
            output_node_names = _make_node_names(model_type, modifier_type)
        print('The following nodes will be frozen: %s' % output_node_names)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # Output node names select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def _setUp(self):
    """Build a dipole model from the test input and freeze it to disk."""
    run_opt = RunOptions(restart=None, init_model=None, log_path=None,
                         log_level=30, mpi_log="master")
    jdata = j_loader(INPUT)

    # init model
    model = DPTrainer(jdata, run_opt=run_opt)
    rcut = model.model.get_rcut()

    # init data system; resolve the configured systems relative to the
    # tests directory
    training = jdata['training']
    systems = [tests_path / ii for ii in j_must_have(training, 'systems')]
    set_pfx = j_must_have(training, 'set_prefix')
    batch_size = j_must_have(training, 'batch_size')
    test_size = j_must_have(training, 'numb_test')
    data = DeepmdDataSystem(systems, batch_size, test_size, rcut,
                            set_prefix=set_pfx)
    data.add_dict(data_requirement)

    # clear the default graph, then build the model with stats taken
    # from the first system
    tf.reset_default_graph()
    model.build(data)

    # freeze the graph
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        input_graph_def = tf.get_default_graph().as_graph_def()
        nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type,model_attr/output_dim,model_attr/model_version"
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, input_graph_def, nodes.split(","))
        output_graph = str(
            tests_path / os.path.join(modifier_datapath, 'dipole.pb'))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
def _setUp(self):
    """Build a dipole model from the JSON input and freeze it to disk."""
    args = Args()
    run_opt = RunOptions(args, False)
    with open (args.INPUT, 'r') as fp:
        jdata = json.load (fp)

    # init model
    model = NNPTrainer (jdata, run_opt = run_opt)
    rcut = model.model.get_rcut()

    # init data system from the 'training' section of the input
    training = jdata['training']
    systems = j_must_have(training, 'systems')
    set_pfx = j_must_have(training, 'set_prefix')
    batch_size = j_must_have(training, 'batch_size')
    test_size = j_must_have(training, 'numb_test')
    data = DeepmdDataSystem(systems, batch_size, test_size, rcut,
                            set_prefix=set_pfx)
    data.add_dict(data_requirement)

    # clear the default graph, then build the model with stats taken
    # from the first system
    tf.reset_default_graph()
    model.build (data)

    # freeze the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        input_graph_def = tf.get_default_graph().as_graph_def()
        nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, input_graph_def, nodes.split(","))
        output_graph = os.path.join(modifier_datapath, 'dipole.pb')
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
def freeze(*, checkpoint_folder: str, output: str, node_names: Optional[str] = None, **kwargs):
    """Freeze the graph in supplied folder.

    Parameters
    ----------
    checkpoint_folder : str
        location of the folder with model
    output : str
        output file name
    node_names : Optional[str], optional
        names of nodes to output, by default None
    """
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # expand the output file to full path
    output_graph = abspath(output)

    # Before exporting our graph, we need to decide which output nodes to
    # keep -- this is how TF decides what part of the Graph it has to keep
    # and what part it can dump.
    # NOTE: this variable is plural, because you can have multiple output
    # nodes, e.g. node_names = "energy_test,force_test,virial_test,t_rcut"

    # We clear devices to allow TensorFlow to control
    # on which device it will load operations
    clear_devices = True

    # In case of parallel training: importing horovod registers its custom
    # ops, which is required to parse a meta graph written by a horovod run.
    try:
        import horovod.tensorflow as _  # noqa: F401
    except ImportError:
        pass

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(
        f"{input_checkpoint}.meta", clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    nodes = [n.name for n in input_graph_def.node]

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        # model/modifier type strings are stored as byte tensors in the graph
        model_type = run_sess(
            sess, "model_attr/model_type:0", feed_dict={}).decode("utf-8")
        if "modifier_attr/type" in nodes:
            modifier_type = run_sess(
                sess, "modifier_attr/type:0", feed_dict={}).decode("utf-8")
        else:
            modifier_type = None
        if node_names is None:
            output_node_list = _make_node_names(model_type, modifier_type)
            different_set = set(output_node_list) - set(nodes)
            if different_set:
                # fixed typo: "freezeing" -> "freezing"
                log.warning(
                    "The following nodes are not in the graph: %s. "
                    "Skip freezing these nodes. You may be freezing "
                    "a checkpoint generated by an old version." % different_set
                )
                # use intersection as output list
                output_node_list = list(set(output_node_list) & set(nodes))
        else:
            output_node_list = node_names.split(",")
        log.info(f"The following nodes will be frozen: {output_node_list}")

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_list,  # The output node names select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        log.info(f"{len(output_graph_def.node):d} ops in the final graph.")