Exemplo n.º 1
0
def save_weights_from_checkpoint(input_checkpoint, output_path, conv_var_names=None, conv_transpose_var_names=None):
    """Restore a checkpoint and hand the live session to ``save_weights``.

    ``save_weights`` saves the trainable variables' weights, each one in a
    different file in ``output_path``.

    Args:
        input_checkpoint: checkpoint to restore; validated by
            ``check_input_checkpoint`` before a session is opened.
        output_path: destination passed straight through to ``save_weights``.
        conv_var_names: forwarded unchanged to ``save_weights``.
        conv_transpose_var_names: forwarded unchanged to ``save_weights``.
    """
    check_input_checkpoint(input_checkpoint)

    # Gather the optional variable-name selections once and forward them as kwargs.
    name_selection = dict(conv_var_names=conv_var_names,
                          conv_transpose_var_names=conv_transpose_var_names)

    with tf.Session() as session:
        restore_from_checkpoint(session, input_checkpoint)
        save_weights(session, output_path, **name_selection)
Exemplo n.º 2
0
def save_graph_only_from_checkpoint(input_checkpoint, output_file_path, output_node_names, as_text=False):
    """Restore a checkpoint and save a small version of its graph.

    The graph is reduced according to the given output node names by ``save_graph_only``.

    Args:
        input_checkpoint: checkpoint to restore; validated by
            ``check_input_checkpoint`` first.
        output_file_path: file the graph is written to by ``save_graph_only``.
        output_node_names: output node names, normalised by
            ``output_node_names_string_as_list`` before use.
        as_text: forwarded to ``save_graph_only`` unchanged.
    """
    check_input_checkpoint(input_checkpoint)
    node_names = output_node_names_string_as_list(output_node_names)

    with tf.Session() as session:
        restore_from_checkpoint(session, input_checkpoint)
        save_graph_only(session, output_file_path, node_names, as_text=as_text)
Exemplo n.º 3
0
def freeze_from_checkpoint(input_checkpoint, output_file_path, output_node_names):
    """Freeze and shrink the graph based on a checkpoint and the output node names.

    The checkpoint is restored into a session so that the session's
    ``graph_def`` can be handed to
    ``freeze_graph.freeze_graph_with_def_protos``. That call writes the frozen
    graph to ``output_file_path`` (passed as ``output_graph``).

    Args:
        input_checkpoint: checkpoint to restore; validated by
            ``check_input_checkpoint`` first.
        output_file_path: file the frozen graph is written to.
        output_node_names: output node names in any form accepted by
            ``output_node_names_string_as_list``.

    Raises:
        ValueError: if no output node names are given.
        RuntimeError: if ``freeze_graph_with_def_protos`` signals failure by
            returning -1 instead of raising.
    """
    check_input_checkpoint(input_checkpoint)

    output_node_names = output_node_names_string_as_list(output_node_names)
    # Fail fast, before the (potentially expensive) restore: an empty selection
    # would be joined into '' below and there would be nothing to freeze.
    if not output_node_names:
        raise ValueError('output_node_names must name at least one node.')

    with tf.Session() as sess:
        restore_from_checkpoint(sess, input_checkpoint)
        # NOTE(review): restore_op_name / filename_tensor_name are fixed to the
        # default Saver names; presumably ignored by newer TF versions - confirm.
        result = freeze_graph.freeze_graph_with_def_protos(input_graph_def=sess.graph_def, input_saver_def=None,
                                                           input_checkpoint=input_checkpoint,
                                                           output_node_names=','.join(output_node_names),
                                                           restore_op_name='save/restore_all',
                                                           filename_tensor_name='save/Const:0',
                                                           output_graph=output_file_path,
                                                           clear_devices=True, initializer_nodes='')

    # NOTE(review): in TF 1.x, freeze_graph_with_def_protos reports some
    # failures (missing checkpoint, empty output_node_names) by printing a
    # message and returning -1 rather than raising. Surface that instead of
    # silently returning without an output file; confirm for the installed TF.
    # The isinstance guard avoids comparing a GraphDef message against an int.
    if isinstance(result, int) and result == -1:
        raise RuntimeError('Freezing checkpoint {!r} into {!r} failed; see the messages printed by '
                           'freeze_graph_with_def_protos.'.format(input_checkpoint, output_file_path))