예제 #1
0
    def _load_graph(frozen_graph_filename: "Path",
                    prefix: str = "load",
                    default_tf_graph: bool = False):
        """Read a frozen ``GraphDef`` file and import it into a TensorFlow graph.

        Parameters
        ----------
        frozen_graph_filename : Path
            location of the serialized graph definition; converted with
            ``str()`` before it is opened
        prefix : str, optional
            value passed as ``name`` to ``tf.import_graph_def``,
            by default "load"
        default_tf_graph : bool, optional
            when True the nodes are imported into the current default graph,
            otherwise into a newly created ``tf.Graph``, by default False

        Returns
        -------
        tf.Graph
            the graph that received the imported nodes
        """
        # Both branches import with exactly the same options.
        import_options = dict(input_map=None,
                              return_elements=None,
                              name=prefix,
                              producer_op_list=None)

        with tf.gfile.GFile(str(frozen_graph_filename), "rb") as pb_file:
            # Deserialize the raw bytes into a GraphDef message.
            parsed_def = tf.GraphDef()
            parsed_def.ParseFromString(pb_file.read())

            if not default_tf_graph:
                # Import into a fresh graph that is the default only inside
                # this ``with`` scope.
                with tf.Graph().as_default() as isolated_graph:
                    tf.import_graph_def(parsed_def, **import_options)
                return isolated_graph

            # Import straight into whichever graph is currently the default.
            tf.import_graph_def(parsed_def, **import_options)
            return tf.get_default_graph()
예제 #2
0
    def _load_graph(self,
                    frozen_graph_filename,
                    prefix='load',
                    default_tf_graph=False):
        """Read a frozen ``GraphDef`` file and import it into a TensorFlow graph.

        Parameters
        ----------
        frozen_graph_filename
            name of the file holding the serialized graph definition
        prefix : str, optional
            value passed as ``name`` to ``tf.import_graph_def``,
            by default 'load'
        default_tf_graph : bool, optional
            when True the nodes are imported into the current default graph,
            otherwise into a newly created ``tf.Graph``, by default False

        Returns
        -------
        tf.Graph
            the graph that received the imported nodes
        """
        # Pull the serialized bytes from disk; the file is closed before the
        # graph is imported.
        with tf.gfile.GFile(frozen_graph_filename, "rb") as handle:
            serialized = handle.read()
        parsed_def = tf.GraphDef()
        parsed_def.ParseFromString(serialized)

        def _import_into_current_default():
            # Imports into whichever graph is the default at call time.
            tf.import_graph_def(parsed_def,
                                input_map=None,
                                return_elements=None,
                                name=prefix,
                                producer_op_list=None)

        if default_tf_graph:
            _import_into_current_default()
            return tf.get_default_graph()

        # A fresh graph is made the default only for the duration of the import.
        with tf.Graph().as_default() as new_graph:
            _import_into_current_default()
        return new_graph
예제 #3
0
def freeze_graph(model_folder, output, output_node_names=None):
    """Freeze the checkpoint found in ``model_folder`` into a single file.

    The meta graph of the checkpoint is imported into the default graph, the
    weights are restored, variables are converted to constants and the
    resulting ``GraphDef`` is serialized to disk.

    Parameters
    ----------
    model_folder : str
        folder that contains the TensorFlow checkpoint state
    output : str
        file name of the frozen graph; it is joined onto the folder of the
        checkpoint file (an absolute path is therefore used as is)
    output_node_names : str, optional
        comma-separated names of the nodes to keep. When None, the names are
        obtained from ``_make_node_names`` using the model type and the
        modifier type stored in the graph.

    Raises
    ------
    RuntimeError
        if no checkpoint state can be found in ``model_folder``
    """
    # Imported locally so the function does not rely on a module-level import.
    import os

    # TensorFlow signals a missing checkpoint by returning None; fail with a
    # clear message instead of an AttributeError on the next line.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None:
        raise RuntimeError(
            "No checkpoint state found in folder: %s" % model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written next to the checkpoint files. Using os.path
    # keeps a checkpoint path without a directory part from resolving to the
    # filesystem root and honours the platform's path separators.
    output_graph = os.path.join(os.path.dirname(input_checkpoint), output)

    # Import the meta graph and retrieve a Saver. Devices are cleared so that
    # TensorFlow decides on which device the operations are loaded.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    # Retrieve the protobuf graph definition and the names of all its nodes.
    input_graph_def = tf.get_default_graph().as_graph_def()
    node_names_in_graph = {n.name for n in input_graph_def.node}

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        model_type = sess.run('model_attr/model_type:0',
                              feed_dict={}).decode('utf-8')
        # The modifier attribute is optional in the graph.
        if 'modifier_attr/type' in node_names_in_graph:
            modifier_type = sess.run('modifier_attr/type:0',
                                     feed_dict={}).decode('utf-8')
        else:
            modifier_type = None
        if output_node_names is None:
            output_node_names = _make_node_names(model_type, modifier_type)
        print('The following nodes will be frozen: %s' % output_node_names)

        # The output node names select which part of the graph is kept; the
        # session supplies the variable values that become constants.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
        )

        # Finally serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
예제 #4
0
    def _setUp(self):
        """Build the model described by ``INPUT`` and freeze it to ``dipole.pb``.

        The frozen graph is written to ``modifier_datapath/dipole.pb``
        relative to ``tests_path``.
        """
        run_opt = RunOptions(restart=None,
                             init_model=None,
                             log_path=None,
                             log_level=30,
                             mpi_log="master")
        jdata = j_loader(INPUT)

        # Create the trainer and read the cut-off radius from its model.
        trainer = DPTrainer(jdata, run_opt=run_opt)
        rcut = trainer.model.get_rcut()

        # Collect the data-system settings from the training section; the
        # system paths are resolved relative to the tests directory.
        training_cfg = jdata['training']
        system_paths = [
            tests_path / entry
            for entry in j_must_have(training_cfg, 'systems')
        ]
        set_pfx = j_must_have(training_cfg, 'set_prefix')
        batch_size = j_must_have(training_cfg, 'batch_size')
        test_size = j_must_have(training_cfg, 'numb_test')
        data = DeepmdDataSystem(system_paths,
                                batch_size,
                                test_size,
                                rcut,
                                set_prefix=set_pfx)
        data.add_dict(data_requirement)

        # Start from an empty default graph before the model is built.
        tf.reset_default_graph()

        # Build the model using the data system.
        trainer.build(data)

        # Initialize the variables, then freeze the listed nodes.
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            input_graph_def = tf.get_default_graph().as_graph_def()
            nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type,model_attr/output_dim,model_attr/model_version"
            frozen_def = tf.graph_util.convert_variables_to_constants(
                sess, input_graph_def, nodes.split(","))
            pb_path = str(tests_path /
                          os.path.join(modifier_datapath, 'dipole.pb'))
            with tf.gfile.GFile(pb_path, "wb") as pb_file:
                pb_file.write(frozen_def.SerializeToString())
예제 #5
0
    def _setUp(self):
        """Build the model described by ``Args().INPUT`` and freeze it.

        The frozen graph is written to ``modifier_datapath/dipole.pb``.
        """
        args = Args()
        run_opt = RunOptions(args, False)
        with open(args.INPUT, 'r') as json_file:
            jdata = json.load(json_file)

        # Create the trainer and read the cut-off radius from its model.
        trainer = NNPTrainer(jdata, run_opt=run_opt)
        rcut = trainer.model.get_rcut()

        # Collect the data-system settings from the training section.
        training_cfg = jdata['training']
        systems = j_must_have(training_cfg, 'systems')
        set_pfx = j_must_have(training_cfg, 'set_prefix')
        batch_size = j_must_have(training_cfg, 'batch_size')
        test_size = j_must_have(training_cfg, 'numb_test')
        data = DeepmdDataSystem(systems,
                                batch_size,
                                test_size,
                                rcut,
                                set_prefix=set_pfx)
        data.add_dict(data_requirement)

        # Start from an empty default graph before the model is built.
        tf.reset_default_graph()

        # Build the model using the data system.
        trainer.build(data)

        # Initialize the variables, then freeze the listed nodes.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            input_graph_def = tf.get_default_graph().as_graph_def()
            nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
            frozen_def = tf.graph_util.convert_variables_to_constants(
                sess, input_graph_def, nodes.split(","))
            pb_path = os.path.join(modifier_datapath, 'dipole.pb')
            with tf.gfile.GFile(pb_path, "wb") as pb_file:
                pb_file.write(frozen_def.SerializeToString())
예제 #6
0
def freeze(*,
           checkpoint_folder: str,
           output: str,
           node_names: Optional[str] = None,
           **kwargs):
    """Freeze the graph in supplied folder.

    The meta graph of the checkpoint is imported into the default graph, the
    weights are restored, variables are converted to constants and the
    resulting ``GraphDef`` is serialized to `output`.

    Parameters
    ----------
    checkpoint_folder : str
        location of the folder with model
    output : str
        output file name
    node_names : Optional[str], optional
        comma-separated names of nodes to output, by default None. When None,
        the names come from ``_make_node_names`` and any of them that are
        missing from the graph are skipped with a warning.
    **kwargs
        additional keyword arguments; accepted but not used here

    Raises
    ------
    RuntimeError
        if no checkpoint state can be found in `checkpoint_folder`
    """
    # TensorFlow signals a missing checkpoint by returning None; fail with a
    # clear message instead of an AttributeError on the next line.
    checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
    if checkpoint is None:
        raise RuntimeError(
            f"No checkpoint state found in folder: {checkpoint_folder}")
    input_checkpoint = checkpoint.model_checkpoint_path

    # expand the output file to full path
    output_graph = abspath(output)

    try:
        # In case of parallel training.
        # NOTE(review): presumably importing horovod makes its ops known to
        # TensorFlow before the meta graph is imported - confirm.
        import horovod.tensorflow as _
    except ImportError:
        # horovod is optional; continue without it.
        pass

    # Import the meta graph and retrieve a Saver. Devices are cleared so that
    # TensorFlow decides on which device the operations are loaded.
    saver = tf.train.import_meta_graph(f"{input_checkpoint}.meta",
                                       clear_devices=True)

    # Retrieve the protobuf graph definition and the names of all its nodes.
    input_graph_def = tf.get_default_graph().as_graph_def()
    graph_node_names = {n.name for n in input_graph_def.node}

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        model_type = run_sess(sess, "model_attr/model_type:0",
                              feed_dict={}).decode("utf-8")
        # The modifier attribute is optional in the graph.
        if "modifier_attr/type" in graph_node_names:
            modifier_type = run_sess(sess,
                                     "modifier_attr/type:0",
                                     feed_dict={}).decode("utf-8")
        else:
            modifier_type = None

        if node_names is None:
            output_node_list = _make_node_names(model_type, modifier_type)
            different_set = set(output_node_list) - graph_node_names
            if different_set:
                log.warning("The following nodes are not in the graph: %s. "
                            "Skip freezeing these nodes. You may be freezing "
                            "a checkpoint generated by an old version.",
                            different_set)
                # Keep only the nodes present in the graph. dict.fromkeys
                # removes duplicates while preserving the original order, so
                # the frozen node list is the same on every run.
                output_node_list = list(
                    dict.fromkeys(name for name in output_node_list
                                  if name in graph_node_names))
        else:
            output_node_list = node_names.split(",")
        log.info(f"The following nodes will be frozen: {output_node_list}")

        # The output node names select which part of the graph is kept; the
        # session supplies the variable values that become constants.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_list,
        )

        # Finally serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        log.info(f"{len(output_graph_def.node):d} ops in the final graph.")