コード例 #1
0
def test_commit__get_model_execution_graph():
    """Smoke-test execution-graph extraction on the ``mnist.*`` networks.

    For each network yielded by ``networks.iterator("mnist.*")`` a Keras
    model is built from the network's ``"in"`` and ``"out"`` entries, its
    execution graph is computed and then printed. The test makes no
    assertions; it passes as long as nothing raises.
    """
    for net in networks.iterator("mnist.*"):
        keras_model = tensorflow.keras.models.Model(
            inputs=net["in"],
            outputs=net["out"],
        )
        execution_graph = kgraph.get_model_execution_graph(keras_model)
        kgraph.print_model_execution_graph(execution_graph)
コード例 #2
0
def test_precommit__get_model_execution_graph_resnet50():
    """Smoke-test execution-graph extraction on ``imagenet.resnet50``.

    For each network yielded by ``networks.iterator("imagenet.resnet50")``
    a Keras model is built from the network's ``"in"`` and ``"out"``
    entries, its execution graph is computed and then printed. The test
    makes no assertions; it passes as long as nothing raises.
    """
    for net in networks.iterator("imagenet.resnet50"):
        keras_model = tensorflow.keras.models.Model(
            inputs=net["in"],
            outputs=net["out"],
        )
        execution_graph = kgraph.get_model_execution_graph(keras_model)
        kgraph.print_model_execution_graph(execution_graph)
コード例 #3
0
ファイル: test_graph.py プロジェクト: zy009197/innvestigate
def test_fast__get_model_execution_graph_with_inputs():
    """Smoke-test execution-graph extraction with input layers kept.

    For each network yielded by
    ``networks.iterator("trivia.*:mnist.log_reg")`` a Keras model is built
    from the network's ``"in"`` and ``"out"`` entries. Its execution graph
    is computed with ``keep_input_layers=True`` and then printed. The test
    makes no assertions; it passes as long as nothing raises.
    """
    for net in networks.iterator("trivia.*:mnist.log_reg"):
        keras_model = keras.models.Model(
            inputs=net["in"],
            outputs=net["out"],
        )
        execution_graph = kgraph.get_model_execution_graph(
            keras_model,
            keep_input_layers=True,
        )
        kgraph.print_model_execution_graph(execution_graph)