コード例 #1
0
def test_get_op_nodes_in_scope():
    """Check scope-based operator node lookup on a wrapped model's original graph.

    Each NNCF module scope of the wrapped ``TwoConvTestModel`` must resolve to
    exactly one ``NNCFNode`` whose id is among the graph's node ids. Scopes
    collected from a wrapped ``BasicConvTestModel`` are then looked up in that
    same graph and must produce an empty result.
    """
    two_conv_model = TwoConvTestModel()
    nncf_model = NNCFNetwork(deepcopy(two_conv_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork
    original_graph = nncf_model.get_original_graph()

    # Scopes that belong to the wrapped model: each one must be found.
    own_modules = nncf_model.get_nncf_modules()
    known_node_ids = list(original_graph.get_all_node_ids())
    for scope in own_modules:
        found_nodes = original_graph.get_op_nodes_in_scope(scope)
        assert len(found_nodes) == 1
        found_node = found_nodes[0]
        assert isinstance(found_node, NNCFNode)
        assert found_node.node_id in known_node_ids

    basic_conv_model = BasicConvTestModel()
    foreign_nncf_model = NNCFNetwork(deepcopy(basic_conv_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])

    # Scopes that come from a different model: none may match in this graph.
    foreign_modules = foreign_nncf_model.get_nncf_modules()
    for scope in foreign_modules:
        found_nodes = original_graph.get_op_nodes_in_scope(scope)
        assert not found_nodes
コード例 #2
0
def test_check_correct_modules_replacement():
    """Check that the replacement checker and the network agree on the NNCF modules.

    The second element returned by ``check_correct_nncf_modules_replacement``
    must contain the same set of entries as ``get_nncf_modules()`` of the
    wrapped network.
    """
    reference_model = TwoConvTestModel()
    # The network wraps its own freshly constructed TwoConvTestModel instance,
    # not ``reference_model`` itself.
    wrapped_model = NNCFNetwork(TwoConvTestModel(), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    _unused, reported_modules = check_correct_nncf_modules_replacement(reference_model, wrapped_model)

    reported_set = set(reported_modules)
    expected_set = set(wrapped_model.get_nncf_modules())
    assert reported_set == expected_set