Example #1
0
def _create_graph(graph_def: GraphDef, weight_dict: Dict[str, Tensor],
                  modifiers: Dict[str, Callable]) -> GraphDef:
    """
    Create a TF Graph from nodes

    Each weight is optionally transformed by its matching modifier and
    converted with ``tf.convert_to_tensor`` inside a fresh graph. The node
    graph is then imported with those tensors as its input map, and the
    resulting graph is run through ``optimize_graph``.

    The caller's ``weight_dict`` is not modified; converted tensors are
    collected in a separate mapping.

    Args:
        graph_def: TF GraphDef message containing the node graph
        weight_dict: Dictionary from node names to tensor data; after
            conversion it is used as the input map for ``graph_def``
        modifiers: Operations to be performed on weights before the
            conversion, keyed by the same node names as ``weight_dict``

    Raises:
        ValueError: The given graph def contains unsupported operations

    Returns:
        TF Graph for inference or saving
    """
    graph = tf.Graph()
    validate_supported_ops(graph_def)
    with tf.compat.v1.Session(graph=graph):
        # Build a new input map instead of overwriting weight_dict in place.
        # Mutating the caller's dict would leak tensors bound to this graph,
        # and a repeated call with the same dict would re-apply modifiers to
        # already-converted tensors.
        input_map = {}
        for name, value in weight_dict.items():
            if name in modifiers:
                value = modifiers[name](value)
            input_map[name] = tf.convert_to_tensor(value)
        tf.graph_util.import_graph_def(graph_def, input_map, name='')

    graph_def = optimize_graph(graph)
    # NOTE(review): the return annotation says GraphDef, but the docstring
    # (and the helper's name) suggest a tf.Graph is returned -- confirm and
    # align the annotation.
    return graph_def_to_graph_v1(graph_def)
 def test_optimize_graph(self):
     """optimize_graph should replace nodes if possible.

     Runs the optimiser over the sample test graph and checks that the
     resulting sequence of op types both changed and became shorter.
     """
     def op_types(graph):
         # op type of every node yielded by _op_nodes, in order
         return [node.op for node in _op_nodes(graph)]

     sample_graph = testutils.get_sample_graph()
     ops_before = op_types(sample_graph)
     optimized_graph = optimization.optimize_graph(sample_graph)
     ops_after = op_types(optimized_graph)
     # the optimised op list must differ from the original and be smaller
     self.assertNotEqual(ops_before, ops_after)
     self.assertLess(len(ops_after), len(ops_before))
def save_tfjs_model(model: Callable, path: str) -> None:
    """Save Keras model as TFJS graph model.

    The model is converted to a graph and optimised. The optimised graph is
    written to *path* as a binary ``frozen_graph.pb``. That file's path and
    *path* are then passed to ``convert_to_tfjs`` together with the
    comma-separated output node names.

    Args:
        model: Keras model to export
        path: Directory that receives the frozen graph and is passed to
            the converter as its final argument
    """
    frozen_name = 'frozen_graph.pb'
    optimized = optimize_graph(model_to_graph(model))
    output_names = ','.join(node.name for node in get_outputs(optimized))
    tf.io.write_graph(optimized, path, frozen_name, as_text=False)
    converter_args = [
        '--input_format=tf_frozen_model',
        '--output_format=tfjs_graph_model',
        f'--output_node_names={output_names}',
        os.path.join(path, frozen_name),
        path,
    ]
    convert_to_tfjs(converter_args)