def test_low_latency2():
    """Run the LowLatency2 pass on a model that contains a TensorIterator.

    The loop body computes ``(X_i + Y_i) * M_body``. ``X`` and ``Y`` are fed
    to the body as sliced inputs and ``M`` as an invariant input. After the
    pass has run, the test expects ``count_ops`` to report exactly one
    ``TensorIterator`` in the model.
    """
    # Inputs of the outer model.
    X = opset8.parameter(Shape([32, 40, 10]), np.float32, "X")
    Y = opset8.parameter(Shape([32, 40, 10]), np.float32, "Y")
    M = opset8.parameter(Shape([32, 2, 10]), np.float32, "M")

    # Inputs of the TensorIterator body.
    X_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "X_i")
    Y_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "Y_i")
    M_body = opset8.parameter(Shape([32, 2, 10]), np.float32, "M_body")

    # Named ``xy_sum`` rather than ``sum`` so the builtin is not shadowed.
    xy_sum = opset8.add(X_i, Y_i)
    Zo = opset8.multiply(xy_sum, M_body)

    body = Model([Zo], [X_i, Y_i, M_body], "body_function")

    ti = opset8.tensor_iterator()
    ti.set_body(body)
    # NOTE(review): the trailing integers are presumably
    # (start, stride, part_size, end, axis) -- confirm against the
    # TensorIterator API. X uses an explicit end of 39, Y uses -1.
    ti.set_sliced_input(X_i, X.output(0), 0, 2, 2, 39, 1)
    ti.set_sliced_input(Y_i, Y.output(0), 0, 2, 2, -1, 1)
    ti.set_invariant_input(M_body, M.output(0))

    # NOTE(review): -1 presumably selects the last iteration -- confirm.
    out0 = ti.get_iter_value(Zo.output(0), -1)
    out1 = ti.get_concatenated_slices(Zo.output(0), 0, 2, 2, 39, 1)

    result0 = opset8.result(out0)
    result1 = opset8.result(out1)

    model = Model([result0, result1], [X, Y, M])

    m = Manager()
    m.register_pass(LowLatency2())
    m.run_passes(model)

    # TODO: create TI which will be transformed by LowLatency2
    assert count_ops(model, "TensorIterator") == [1]

# Example #2
def test_simple_pattern_replacement():
    """Register a MatcherPass built from a plain matcher/callback pair.

    The pattern is a Relu fed by a Parameter; the callback replaces the
    matched root with an Exp node. The test expects zero ``Relu`` and one
    ``Exp`` op in the model afterwards.
    """
    # Simple form intended for extensions: no classes, no inheritance.
    def build_matcher_and_callback():
        param_pattern = WrapType("opset8.Parameter")
        relu_pattern = WrapType("opset8.Relu", param_pattern.output(0))

        def on_match(matcher: Matcher) -> bool:
            matched_root = matcher.get_match_root()

            # Checks that the closure captured the pattern node and that
            # pattern nodes can be linked to the matched graph nodes.
            assert relu_pattern in matcher.get_pattern_value_map()

            # Equivalent source: matched_root.input(0).get_source_output()
            exp_node = opset8.exp(matched_root.input_value(0))
            replace_node(matched_root, exp_node)
            return True

        return Matcher(relu_pattern, "SimpleReplacement"), on_match

    model = get_test_function()

    matcher, callback = build_matcher_and_callback()

    pass_manager = Manager()
    pass_manager.register_pass(MatcherPass(matcher, callback))
    pass_manager.run_passes(model)

    assert count_ops(model, ("Relu", "Exp")) == [0, 1]

# Example #3
def test_matcher_pass_apply():
    """Apply PatternReplacement directly to one node, without a Manager.

    The pass is applied to the node that produces the model result's first
    input; the test then expects ``count_ops`` to report two ``Relu`` ops.
    """
    model = get_test_function()

    replacement = PatternReplacement()
    # Node feeding input 0 of the model's result.
    result_producer = model.get_result().input_value(0).get_node()
    replacement.apply(result_producer)

    relu_counts = count_ops(model, "Relu")
    assert relu_counts == [2]
def test_constant_folding():
    """Run ConstantFolding through a Manager and check the resulting model.

    The test expects the model to still be a valid object and ``count_ops``
    to report no ``ShapeOf`` ops after the pass.
    """
    model = get_model()

    pass_manager = Manager()
    pass_manager.register_pass(ConstantFolding())
    pass_manager.run_passes(model)

    assert model is not None
    shape_of_counts = count_ops(model, "ShapeOf")
    assert shape_of_counts == [0]
def test_graph_rewrite():
    """Run PatternReplacement as a matcher added to a GraphRewrite pass.

    Relies on ``register_pass`` returning the registered pass instance so
    that a matcher can be added to it before the passes are run. The test
    expects ``count_ops`` to report two ``Relu`` ops afterwards.
    """
    model = get_test_function()

    pass_manager = Manager()
    # register_pass is expected to hand back the pass instance.
    rewrite_pass = pass_manager.register_pass(GraphRewrite())
    rewrite_pass.add_matcher(PatternReplacement())
    pass_manager.run_passes(model)

    relu_counts = count_ops(model, "Relu")
    assert relu_counts == [2]

# Example #6
def test_matcher_pass():
    """Run PatternReplacement through a Manager and inspect the pass object.

    Relies on ``register_pass`` returning the registered pass instance. The
    test expects its ``model_changed`` attribute to be truthy after the run
    and ``count_ops`` to report two ``Relu`` ops.
    """
    model = get_test_function()

    pass_manager = Manager()
    # register_pass is expected to hand back the pass instance.
    registered = pass_manager.register_pass(PatternReplacement())
    pass_manager.run_passes(model)

    assert registered.model_changed
    relu_counts = count_ops(model, "Relu")
    assert relu_counts == [2]