def test_simple_pattern_replacement():
    """Replace Relu with Exp using a bare Matcher + callback (no subclassing).

    This is the "simple" flavour intended for extensions: the pattern and its
    callback are produced by a plain function and handed straight to
    ``MatcherPass`` without defining any pass class.
    """

    def build_relu_to_exp():
        parameter = WrapType("opset8.Parameter")
        relu_pattern = WrapType("opset8.Relu", parameter.output(0))

        def on_match(matcher: Matcher) -> bool:
            matched_root = matcher.get_match_root()
            # Capturing check: the pattern node must appear in the
            # pattern->graph value map, i.e. pattern nodes can be linked to
            # the graph nodes they matched.
            assert relu_pattern in matcher.get_pattern_value_map()
            # input_value(0) is equivalent to input(0).get_source_output().
            replacement = opset8.exp(matched_root.input_value(0))
            replace_node(matched_root, replacement)
            return True

        return Matcher(relu_pattern, "SimpleReplacement"), on_match

    model = get_test_function()
    manager = Manager()
    manager.register_pass(MatcherPass(*build_relu_to_exp()))
    manager.run_passes(model)

    assert count_ops(model, ("Relu", "Exp")) == [0, 1]
def test_matcher_pass_apply():
    """Call PatternReplacement.apply on one node directly, without a Manager.

    The node passed in is the one feeding the model's Result.
    """
    model = get_test_function()
    replacement_pass = PatternReplacement()
    target_node = model.get_result().input_value(0).get_node()
    replacement_pass.apply(target_node)

    assert count_ops(model, "Relu") == [2]
def test_graph_rewrite():
    """Register a GraphRewrite anchor and attach PatternReplacement to it."""
    model = get_test_function()
    manager = Manager()
    # register_pass hands back the registered pass instance; the returned
    # anchor is what the matcher gets attached to.
    rewrite_anchor = manager.register_pass(GraphRewrite())
    rewrite_anchor.add_matcher(PatternReplacement())
    manager.run_passes(model)

    assert count_ops(model, "Relu") == [2]
def test_matcher_pass():
    """Run PatternReplacement through a Manager and check its change flag."""
    model = get_test_function()
    manager = Manager()
    # register_pass returns the pass instance, so its state (model_changed)
    # can be inspected after the run.
    replacement = manager.register_pass(PatternReplacement())
    manager.run_passes(model)

    assert replacement.model_changed
    assert count_ops(model, "Relu") == [2]
def test_register_new_node():
    """Insert Exp after each Parameter, register it, then remove it again.

    InsertExp uses ``register_new_node`` on the Exp it creates. RemoveExp,
    registered as a second pass, matches Exp nodes and bypasses them. Both
    passes record in ``model_changed`` whether their callback fired.
    """

    class InsertExp(MatcherPass):
        def __init__(self):
            MatcherPass.__init__(self)
            self.model_changed = False

            parameter_pattern = WrapType("opset8.Parameter")

            def on_parameter(matcher: Matcher) -> bool:
                # Input->...->Result  ==>  Input->Exp->...->Result
                matched_output = matcher.get_match_value()
                # Collect consumers BEFORE creating Exp; otherwise the new Exp
                # would itself be a consumer and get rewired onto itself.
                downstream_inputs = matched_output.get_target_inputs()
                exp_node = opset8.exp(matched_output)
                for target_input in downstream_inputs:
                    target_input.replace_source_output(exp_node.output(0))
                # Test bookkeeping.
                self.model_changed = True
                # Make the freshly created Exp available for further matching.
                self.register_new_node(exp_node)
                # The matched root itself was neither replaced nor changed.
                return False

            self.register_matcher(Matcher(parameter_pattern, "InsertExp"), on_parameter)

    class RemoveExp(MatcherPass):
        def __init__(self):
            MatcherPass.__init__(self)
            self.model_changed = False

            exp_pattern = WrapType("opset8.Exp")

            def on_exp(matcher: Matcher) -> bool:
                exp_root = matcher.get_match_root()
                # Bypass Exp: route its consumers to Exp's own input.
                exp_root.output(0).replace(exp_root.input_value(0))
                # Test bookkeeping.
                self.model_changed = True
                return True

            self.register_matcher(Matcher(exp_pattern, "RemoveExp"), on_exp)

    manager = Manager()
    inserter = manager.register_pass(InsertExp())
    remover = manager.register_pass(RemoveExp())
    manager.run_passes(get_test_function())

    assert inserter.model_changed
    assert remover.model_changed
def test_serialize_pass():
    """Serialize a model to IR files, read it back and compare with the original.

    The IR files are written to the current working directory. They are removed
    in a ``finally`` block, so a failing assertion or a serialization error
    does not leave stale files behind for later runs to pick up.
    """
    core = Core()
    xml_path = "serialized_function.xml"
    bin_path = "serialized_function.bin"

    func = get_test_function()

    m = Manager()
    m.register_pass(Serialize(xml_path, bin_path))
    try:
        m.run_passes(func)
        assert func is not None

        res_func = core.read_model(model=xml_path, weights=bin_path)

        assert func.get_parameters() == res_func.get_parameters()
        assert func.get_ordered_ops() == res_func.get_ordered_ops()
    finally:
        # Serialization may have failed before either file was written, so
        # only remove what actually exists.
        for path in (xml_path, bin_path):
            if os.path.exists(path):
                os.remove(path)
def test_model_pass():
    """Run MyModelPass via a Manager and check that it reports a change."""
    manager = Manager()
    model_pass = manager.register_pass(MyModelPass())
    manager.run_passes(get_test_function())
    assert model_pass.model_changed