Example #1
0
def test_any_input_predicate():
    """AnyInput with a predicate matches only outputs for which the predicate holds (rank 4 here)."""
    four_d_param = opset8.parameter(PartialShape([1, 3, 22, 22]))
    scalar_param = opset8.parameter(PartialShape([]))

    def has_rank_four(output):
        # The length of the output's shape is its rank.
        return len(output.get_shape()) == 4

    matcher = Matcher(AnyInput(has_rank_four), "FindActivation")
    assert matcher.match(four_d_param)
    assert not matcher.match(scalar_param)
Example #2
0
def test_any_input():
    """A PRelu pattern with two AnyInput() placeholders matches a PRelu but not a Relu."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    relu_node = opset8.relu(data.output(0))
    slope_param = opset8.parameter(PartialShape([]))
    prelu_node = opset8.prelu(data.output(0), slope_param.output(0))

    prelu_pattern = WrapType("opset8.PRelu", [AnyInput(), AnyInput()])
    matcher = Matcher(prelu_pattern, "FindActivation")

    # Relu has the wrong op type; PRelu matches whatever feeds its two inputs.
    assert not matcher.match(relu_node)
    assert matcher.match(prelu_node)
Example #3
0
def test_or():
    """An Or pattern matches a node when any one of its alternatives matches it."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    relu_node = opset8.relu(data.output(0))
    slope_param = opset8.parameter(PartialShape([]))
    prelu_node = opset8.prelu(data.output(0), slope_param.output(0))

    alternatives = [WrapType("opset8.Relu"), WrapType("opset8.PRelu")]
    matcher = Matcher(Or(alternatives), "FindActivation")

    # Each activation is covered by exactly one branch of the Or.
    for activation in (relu_node, prelu_node):
        assert matcher.match(activation)
Example #4
0
        def callback(m: Matcher) -> bool:
            """Replace the matched node with an Exp fed by the matched node's first input."""
            matched = m.get_match_root()

            # Sanity check that the pattern node `relu` was captured, i.e. pattern
            # nodes can be used to look up the graph values they matched.
            captured = m.get_pattern_value_map()
            assert relu in captured

            # input_value(0) is equivalent to input(0).get_source_output().
            exp_node = opset8.exp(matched.input_value(0))
            replace_node(matched, exp_node)
            return True
        def __init__(self):
            """Register an "InsertExp" matcher that puts an Exp after every Parameter output."""
            MatcherPass.__init__(self)
            self.model_changed = False

            parameter_pattern = WrapType("opset8.Parameter")

            def insert_exp(m: Matcher) -> bool:
                """Rewrite Input->...->Result into Input->Exp->...->Result."""
                matched_output = m.get_match_value()
                # Collect the consumers BEFORE building the Exp. The Exp will itself
                # consume matched_output and must not be rewired onto its own output.
                original_consumers = matched_output.get_target_inputs()

                exp_node = opset8.exp(matched_output)
                exp_output = exp_node.output(0)
                for target_input in original_consumers:
                    target_input.replace_source_output(exp_output)

                # Lets tests observe that the pass modified the model.
                self.model_changed = True

                # Make the freshly created Exp available for further matching.
                self.register_new_node(exp_node)

                # The matched root itself was neither replaced nor changed.
                return False

            self.register_matcher(Matcher(parameter_pattern, "InsertExp"), insert_exp)
            def callback(m: Matcher) -> bool:
                """Splice the matched node out by replacing its first output with its first input."""
                matched = m.get_match_root()
                matched_out = matched.output(0)
                matched_out.replace(matched.input_value(0))

                # Lets tests observe that the pass modified the model.
                self.model_changed = True

                return True
Example #7
0
        def callback(m: Matcher) -> bool:
            """Insert a second Relu in front of the matched one.

            Input->Relu->Result becomes Input->Relu->Relu->Result.
            """
            self.applied = True
            matched_relu = m.get_match_root()
            inner_relu = opset8.relu(matched_relu.input(0).get_source_output())

            # Lets tests observe that the pass modified the model.
            self.model_changed = True

            # NOTE(review): inner_relu is deliberately not passed to register_new_node()
            # (that call was commented out originally). If this pass's pattern matches
            # Relu, registering it would presumably re-trigger this callback repeatedly.
            matched_relu.input(0).replace_source_output(inner_relu.output(0))
            return True
Example #8
0
def test_wrap_type_ctors():
    """WrapType accepts a list of op types, optionally constrained by an input pattern."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    relu_node = opset8.relu(data.output(0))
    slope_param = opset8.parameter(PartialShape([]))
    prelu_node = opset8.prelu(data.output(0), slope_param.output(0))

    # Multi-type WrapType with no input constraint: either activation matches.
    any_activation = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"]), "FindActivation")
    assert any_activation.match(relu_node)
    assert any_activation.match(prelu_node)

    # Same op types, but the (single) input must come from a Parameter.
    parameter_input = WrapType("opset8.Parameter").output(0)
    fed_by_parameter = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"], parameter_input),
                               "FindActivation")
    assert fed_by_parameter.match(relu_node)
        def __init__(self):
            """Register a "RemoveExp" matcher that splices every Exp node out of the graph."""
            MatcherPass.__init__(self)
            self.model_changed = False

            exp_pattern = WrapType("opset8.Exp")

            def remove_exp(m: Matcher) -> bool:
                """Replace the matched Exp's first output with the Exp's first input."""
                exp_node = m.get_match_root()
                exp_output = exp_node.output(0)
                exp_output.replace(exp_node.input_value(0))

                # Lets tests observe that the pass modified the model.
                self.model_changed = True

                return True

            self.register_matcher(Matcher(exp_pattern, "RemoveExp"), remove_exp)
            def callback(m: Matcher) -> bool:
                """Rewrite Input->...->Result into Input->Exp->...->Result."""
                matched_output = m.get_match_value()
                # Collect the consumers BEFORE building the Exp. The Exp will itself
                # consume matched_output and must not be rewired onto its own output.
                original_consumers = matched_output.get_target_inputs()

                exp_node = opset8.exp(matched_output)
                exp_output = exp_node.output(0)
                for target_input in original_consumers:
                    target_input.replace_source_output(exp_output)

                # Lets tests observe that the pass modified the model.
                self.model_changed = True

                # Make the freshly created Exp available for further matching.
                self.register_new_node(exp_node)

                # The matched root itself was neither replaced nor changed.
                return False
Example #11
0
    def pattern_replacement():
        """Build a "SimpleReplacement" matcher that turns Parameter->Relu into Parameter->Exp.

        Returns:
            A ``(Matcher, callback)`` pair ready to be registered on a pass.
        """
        relu_pattern = WrapType("opset8.Relu", WrapType("opset8.Parameter").output(0))

        def replace_relu_with_exp(m: Matcher) -> bool:
            matched_relu = m.get_match_root()

            # Sanity check that the pattern node was captured, i.e. pattern nodes
            # can be used to look up the graph values they matched.
            assert relu_pattern in m.get_pattern_value_map()

            # input_value(0) is equivalent to input(0).get_source_output().
            exp_node = opset8.exp(matched_relu.input_value(0))
            replace_node(matched_relu, exp_node)
            return True

        return Matcher(relu_pattern, "SimpleReplacement"), replace_relu_with_exp
Example #12
0
    def __init__(self):
        """Register a "PatternReplacement" matcher that inserts an extra Relu before each Relu.

        Instance attributes:
            model_changed: set to True by the callback once it rewrites the graph.
            applied: set to True by the callback once it has run.
        """
        MatcherPass.__init__(self)
        self.model_changed = False
        # The callback sets `applied`. Initialise it here so that reading it on a
        # pass whose callback never fired yields False instead of AttributeError.
        self.applied = False

        # Same "opsetN.OpName" spelling as the other WrapType patterns in this file.
        relu = WrapType("opset8.Relu")

        def callback(m: Matcher) -> bool:
            self.applied = True
            root = m.get_match_root()
            new_relu = opset8.relu(root.input(0).get_source_output())

            # For testing purpose
            self.model_changed = True

            # new_relu is intentionally NOT passed to register_new_node(). It is itself a
            # Relu, so the Relu pattern above would presumably match it again and keep
            # inserting Relus.

            # Input->Relu->Result => Input->Relu->Relu->Result
            root.input(0).replace_source_output(new_relu.output(0))
            return True

        self.register_matcher(Matcher(relu, "PatternReplacement"), callback)
Example #13
0
def test_all_predicates():
    """Check each built-in pattern predicate against static and dynamic Parameter nodes."""
    # np.int64 instead of np.long: the np.long alias was removed in NumPy 1.24
    # (AttributeError there) and re-added in NumPy 2.0 as the platform C `long`,
    # whose width differs between platforms.
    f32 = get_element_type(np.float32)
    i64 = get_element_type(np.int64)

    static_param = opset8.parameter(PartialShape([1, 3, 22, 22]), np.float32)
    dynamic_param = opset8.parameter(PartialShape([-1, 6]), np.int64)
    fully_dynamic_param = opset8.parameter(PartialShape.dynamic())

    def parameter_matcher(predicate):
        # Fresh single-node Parameter pattern constrained by `predicate`.
        return Matcher(WrapType("opset8.Parameter", predicate), "Test")

    assert parameter_matcher(consumers_count(0)).match(static_param)
    assert not parameter_matcher(consumers_count(1)).match(static_param)

    assert parameter_matcher(has_static_dim(1)).match(static_param)
    assert not parameter_matcher(has_static_dim(0)).match(dynamic_param)

    assert parameter_matcher(has_static_dims([0, 3])).match(static_param)
    assert not parameter_matcher(has_static_dims([0, 1])).match(dynamic_param)

    assert parameter_matcher(has_static_shape()).match(static_param)
    assert not parameter_matcher(has_static_shape()).match(dynamic_param)

    assert parameter_matcher(has_static_rank()).match(dynamic_param)
    assert not parameter_matcher(has_static_rank()).match(fully_dynamic_param)

    assert parameter_matcher(type_matches(f32)).match(static_param)
    assert not parameter_matcher(type_matches(f32)).match(dynamic_param)

    assert parameter_matcher(type_matches_any([f32, i64])).match(static_param)
    assert parameter_matcher(type_matches_any([f32, i64])).match(dynamic_param)