# Example #1
0
def test_simple_pattern_replacement():
    """Register a function-built matcher/callback pair and check the rewrite.

    The pattern and callback are produced by a plain nested function (no
    MatcherPass subclass), wrapped in a ``MatcherPass`` and run through a
    ``Manager``. Afterwards the model is expected to report zero "Relu"
    and one "Exp" via ``count_ops``.
    """
    def pattern_replacement():
        # Pattern nodes: a Relu wrapper whose input is output 0 of a
        # Parameter wrapper.
        parameter_pattern = WrapType("opset8.Parameter")
        relu_pattern = WrapType("opset8.Relu", parameter_pattern.output(0))

        def callback(matcher: Matcher) -> bool:
            matched_root = matcher.get_match_root()

            # Sanity check that capturing works: the pattern node must be
            # a key of the matcher's pattern value map, which links
            # pattern nodes with the matched graph nodes.
            assert relu_pattern in matcher.get_pattern_value_map()

            # Alternative way to get the same input noted by the author:
            # matched_root.input(0).get_source_output()
            source_value = matched_root.input_value(0)
            exp_node = opset8.exp(source_value)
            replace_node(matched_root, exp_node)
            return True

        return Matcher(relu_pattern, "SimpleReplacement"), callback

    model = get_test_function()

    manager = Manager()
    matcher, callback = pattern_replacement()
    manager.register_pass(MatcherPass(matcher, callback))
    manager.run_passes(model)

    op_counts = count_ops(model, ("Relu", "Exp"))
    assert op_counts == [0, 1]
# Example #2
0
def test_matcher_pass_apply():
    """Apply PatternReplacement directly to one node, without a Manager.

    The node passed to ``apply`` is the one producing input 0 of the
    model's result. The model is then expected to report two "Relu" ops.
    """
    model = get_test_function()

    replacement_pass = PatternReplacement()
    result_producer = model.get_result().input_value(0).get_node()
    replacement_pass.apply(result_producer)

    relu_count = count_ops(model, "Relu")
    assert relu_count == [2]
def test_graph_rewrite():
    """Run PatternReplacement as a matcher attached to a GraphRewrite pass.

    The model is expected to report two "Relu" ops after the passes run.
    """
    model = get_test_function()

    manager = Manager()
    # register_pass must return the pass instance; the matcher is attached
    # through that returned object.
    graph_rewrite = manager.register_pass(GraphRewrite())
    graph_rewrite.add_matcher(PatternReplacement())
    manager.run_passes(model)

    relu_count = count_ops(model, "Relu")
    assert relu_count == [2]
# Example #4
0
def test_matcher_pass():
    """Register PatternReplacement on a Manager and run it on the test model.

    Checks that the pass's ``model_changed`` flag is set and that the
    model reports two "Relu" ops afterwards.
    """
    model = get_test_function()

    manager = Manager()
    # register_pass must return the pass instance; its flag is inspected
    # after the run.
    registered_pass = manager.register_pass(PatternReplacement())
    manager.run_passes(model)

    assert registered_pass.model_changed
    relu_count = count_ops(model, "Relu")
    assert relu_count == [2]
def test_register_new_node():
    """Exercise ``register_new_node`` with an insert pass and a remove pass.

    ``InsertExp`` matches Parameter nodes, inserts an Exp after them and
    registers the new node for additional matching. ``RemoveExp`` matches
    Exp nodes and bypasses them. Both passes are registered on a single
    Manager, and both must have had their callbacks invoked.
    """
    class InsertExp(MatcherPass):
        """Matcher pass that inserts an Exp after each matched Parameter."""

        def __init__(self):
            MatcherPass.__init__(self)
            self.model_changed = False

            parameter_pattern = WrapType("opset8.Parameter")

            def callback(matcher: Matcher) -> bool:
                # Input->...->Result => Input->Exp->...->Result
                matched_output = matcher.get_match_value()
                # Collected before the Exp is created -- presumably so
                # that the Exp's own input is not among those rewired.
                original_consumers = matched_output.get_target_inputs()

                exp_node = opset8.exp(matched_output)
                for target_input in original_consumers:
                    target_input.replace_source_output(exp_node.output(0))

                # Flag read by the test to confirm the callback ran.
                self.model_changed = True

                # Make the new operation available for additional matching.
                self.register_new_node(exp_node)

                # The root node itself was neither replaced nor changed.
                return False

            self.register_matcher(Matcher(parameter_pattern, "InsertExp"), callback)

    class RemoveExp(MatcherPass):
        """Matcher pass that bypasses each matched Exp node."""

        def __init__(self):
            MatcherPass.__init__(self)
            self.model_changed = False

            exp_pattern = WrapType("opset8.Exp")

            def callback(matcher: Matcher) -> bool:
                matched_exp = matcher.get_match_root()
                # Replace the Exp's output 0 with the value feeding its input 0.
                exp_source = matched_exp.input_value(0)
                matched_exp.output(0).replace(exp_source)

                # Flag read by the test to confirm the callback ran.
                self.model_changed = True

                return True

            self.register_matcher(Matcher(exp_pattern, "RemoveExp"), callback)

    manager = Manager()
    insert_pass = manager.register_pass(InsertExp())
    remove_pass = manager.register_pass(RemoveExp())
    manager.run_passes(get_test_function())

    assert insert_pass.model_changed
    assert remove_pass.model_changed
def test_serialize_pass():
    """Serialize the test model with the Serialize pass and read it back.

    The model is written to ``serialized_function.xml`` / ``.bin`` in the
    current working directory, re-read through ``Core.read_model`` and
    compared with the original by parameters and ordered ops.

    The generated files are removed in a ``finally`` block so that a
    failing assertion (or a failing pass) does not leave artifacts behind.
    """
    core = Core()
    xml_path = "serialized_function.xml"
    bin_path = "serialized_function.bin"

    func = get_test_function()

    try:
        m = Manager()
        m.register_pass(Serialize(xml_path, bin_path))
        m.run_passes(func)

        assert func is not None

        res_func = core.read_model(model=xml_path, weights=bin_path)

        assert func.get_parameters() == res_func.get_parameters()
        assert func.get_ordered_ops() == res_func.get_ordered_ops()
    finally:
        # Only remove files that were actually produced; a missing file
        # here must not mask the original failure with FileNotFoundError.
        for path in (xml_path, bin_path):
            if os.path.exists(path):
                os.remove(path)
# Example #7
0
def test_model_pass():
    """Register MyModelPass on a Manager, run it, and check its flag.

    The pass instance returned by ``register_pass`` must have
    ``model_changed`` set after the passes run on the test model.
    """
    manager = Manager()
    model_pass = manager.register_pass(MyModelPass())
    manager.run_passes(get_test_function())

    assert model_pass.model_changed