def test_low_latency2():
    """Run the LowLatency2 pass over a model containing a TensorIterator.

    The model has three outer inputs (X, Y, M) and a TensorIterator whose
    body computes ``(X_i + Y_i) * M_body``. After the pass has run, the test
    expects ``count_ops`` to report ``[1]`` for "TensorIterator".
    """

    def make_param(dims, name):
        # Every parameter in this test is float32; only shape and name vary.
        return opset8.parameter(Shape(dims), np.float32, name)

    # Inputs of the outer model.
    X = make_param([32, 40, 10], "X")
    Y = make_param([32, 40, 10], "Y")
    M = make_param([32, 2, 10], "M")

    # Inputs of the TensorIterator body.
    X_i = make_param([32, 2, 10], "X_i")
    Y_i = make_param([32, 2, 10], "Y_i")
    M_body = make_param([32, 2, 10], "M_body")

    # Body computation: (X_i + Y_i) * M_body.
    xy_sum = opset8.add(X_i, Y_i)
    Zo = opset8.multiply(xy_sum, M_body)
    body_model = Model([Zo], [X_i, Y_i, M_body], "body_function")

    # Wire the outer inputs into the iterator body.
    tensor_iter = opset8.tensor_iterator()
    tensor_iter.set_body(body_model)
    tensor_iter.set_sliced_input(X_i, X.output(0), 0, 2, 2, 39, 1)
    tensor_iter.set_sliced_input(Y_i, Y.output(0), 0, 2, 2, -1, 1)
    tensor_iter.set_invariant_input(M_body, M.output(0))

    # Two outputs: one iteration value and the concatenated slices.
    iter_value = tensor_iter.get_iter_value(Zo.output(0), -1)
    concatenated = tensor_iter.get_concatenated_slices(
        Zo.output(0), 0, 2, 2, 39, 1
    )

    result0 = opset8.result(iter_value)
    result1 = opset8.result(concatenated)
    model = Model([result0, result1], [X, Y, M])

    manager = Manager()
    manager.register_pass(LowLatency2())
    manager.run_passes(model)

    # TODO: create TI which will be transformed by LowLatency2
    assert count_ops(model, "TensorIterator") == [1]
def test_simple_pattern_replacement():
    """Replace a Relu fed by a Parameter with Exp using a function-based pass.

    The matcher and its callback are produced by a plain nested function and
    handed to ``MatcherPass``. After the pass has run, the test expects
    ``count_ops`` to report ``[0, 1]`` for ("Relu", "Exp").
    """

    # Simple form aimed at extensions: plain functions only, no classes
    # and no inheritance.
    def pattern_replacement():
        param_pattern = WrapType("opset8.Parameter")
        relu_pattern = WrapType("opset8.Relu", param_pattern.output(0))

        def callback(matcher: Matcher) -> bool:
            matched_root = matcher.get_match_root()

            # Sanity check that the closure captured the pattern node and
            # that it can be looked up in the pattern value map.
            assert relu_pattern in matcher.get_pattern_value_map()

            # The original author notes root.input(0).get_source_output()
            # as an alternative to input_value(0).
            exp_node = opset8.exp(matched_root.input_value(0))
            replace_node(matched_root, exp_node)
            return True

        return Matcher(relu_pattern, "SimpleReplacement"), callback

    model = get_test_function()

    manager = Manager()
    relu_matcher, on_match = pattern_replacement()
    manager.register_pass(MatcherPass(relu_matcher, on_match))
    manager.run_passes(model)

    assert count_ops(model, ("Relu", "Exp")) == [0, 1]
def test_matcher_pass_apply():
    """Call ``PatternReplacement.apply`` directly on a single node.

    The pass is applied, without a ``Manager``, to the node that produces
    the first input of the model's result. The test then expects
    ``count_ops`` to report ``[2]`` for "Relu".
    """
    model = get_test_function()
    replacement = PatternReplacement()

    # Node feeding input 0 of the model's result.
    target_node = model.get_result().input_value(0).get_node()
    replacement.apply(target_node)

    assert count_ops(model, "Relu") == [2]
def test_constant_folding():
    """Run ConstantFolding through a Manager and check the resulting model.

    After the pass has run, the model must still be a non-None object and
    ``count_ops`` must report ``[0]`` for "ShapeOf".
    """
    model = get_model()

    manager = Manager()
    manager.register_pass(ConstantFolding())
    manager.run_passes(model)

    assert model is not None
    shape_of_counts = count_ops(model, "ShapeOf")
    assert shape_of_counts == [0]
def test_graph_rewrite():
    """Run PatternReplacement as a matcher attached to a GraphRewrite pass.

    The test relies on ``register_pass`` returning the registered pass so
    that a matcher can be added to it afterwards. After the passes have
    run, ``count_ops`` is expected to report ``[2]`` for "Relu".
    """
    model = get_test_function()
    manager = Manager()

    # register_pass is expected to hand back the pass instance; the matcher
    # is attached through that returned object.
    rewrite_pass = manager.register_pass(GraphRewrite())
    rewrite_pass.add_matcher(PatternReplacement())

    manager.run_passes(model)

    assert count_ops(model, "Relu") == [2]
def test_matcher_pass():
    """Register PatternReplacement with a Manager and run it on a model.

    The test relies on ``register_pass`` returning the registered pass.
    After the run, that object's ``model_changed`` attribute must be truthy
    and ``count_ops`` must report ``[2]`` for "Relu".
    """
    model = get_test_function()
    manager = Manager()

    # register_pass is expected to hand back the pass instance, which is
    # inspected after the run.
    replacement = manager.register_pass(PatternReplacement())
    manager.run_passes(model)

    assert replacement.model_changed
    assert count_ops(model, "Relu") == [2]