def test_any_input_predicate():
    """AnyInput with a predicate matches only outputs whose shape has rank 4."""

    def has_rank_four(output):
        return len(output.get_shape()) == 4

    activation = opset8.parameter(PartialShape([1, 3, 22, 22]))
    scalar = opset8.parameter(PartialShape([]))
    matcher = Matcher(AnyInput(has_rank_four), "FindActivation")

    assert matcher.match(activation)
    assert not matcher.match(scalar)
def test_any_input():
    """A PRelu pattern with two AnyInput operands matches PRelu but not Relu."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    data_out = data.output(0)

    relu = opset8.relu(data_out)
    slope = opset8.parameter(PartialShape([]))
    prelu = opset8.prelu(data_out, slope.output(0))

    prelu_pattern = WrapType("opset8.PRelu", [AnyInput(), AnyInput()])
    matcher = Matcher(prelu_pattern, "FindActivation")

    assert not matcher.match(relu)
    assert matcher.match(prelu)
def test_or():
    """An Or pattern over Relu and PRelu matches either activation."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    relu = opset8.relu(data.output(0))
    slope = opset8.parameter(PartialShape([]))
    prelu = opset8.prelu(data.output(0), slope.output(0))

    alternatives = [WrapType("opset8.Relu"), WrapType("opset8.PRelu")]
    matcher = Matcher(Or(alternatives), "FindActivation")

    for node in (relu, prelu):
        assert matcher.match(node)
def callback(m: Matcher) -> bool:
    """Replace the matched root with an Exp fed by the root's first input.

    ``relu`` (the pattern node) and ``replace_node`` are names defined
    outside this function.
    """
    matched = m.get_match_root()
    # Capturing check: the pattern node must be linked to a matched graph value.
    assert relu in m.get_pattern_value_map()
    # input_value(0) is equivalent to input(0).get_source_output().
    source = matched.input_value(0)
    replace_node(matched, opset8.exp(source))
    return True
def __init__(self):
    MatcherPass.__init__(self)
    # Flipped to True by the callback; used for testing purposes.
    self.model_changed = False

    parameter_pattern = WrapType("opset8.Parameter")

    def callback(m: Matcher) -> bool:
        # Input->...->Result  =>  Input->Exp->...->Result
        matched_output = m.get_match_value()
        # Consumers are collected before Exp is created, so Exp's own input
        # is not among the inputs that get rewired below.
        consumers = matched_output.get_target_inputs()
        exp_node = opset8.exp(matched_output)
        exp_output = exp_node.output(0)
        for target in consumers:
            target.replace_source_output(exp_output)
        self.model_changed = True
        # Hand the new Exp to the pass so it can take part in further matching.
        self.register_new_node(exp_node)
        # False: the matched root itself was neither replaced nor changed.
        return False

    matcher = Matcher(parameter_pattern, "InsertExp")
    self.register_matcher(matcher, callback)
def callback(m: Matcher) -> bool:
    """Bypass the matched root by rerouting its consumers to its first input.

    ``self`` is taken from the enclosing scope; ``model_changed`` is set
    for testing purposes.
    """
    node = m.get_match_root()
    upstream = node.input_value(0)
    node.output(0).replace(upstream)
    self.model_changed = True
    return True
def callback(m: Matcher) -> bool:
    """Insert a new Relu between the matched root and its first input.

    Input->Relu->Result  =>  Input->Relu->Relu->Result
    ``self`` is taken from the enclosing scope.
    """
    self.applied = True
    matched = m.get_match_root()
    upstream = matched.input(0).get_source_output()
    extra_relu = opset8.relu(upstream)
    # For testing purposes.
    self.model_changed = True
    # NOTE(review): extra_relu is intentionally not passed to
    # self.register_new_node; presumably so the inserted Relu is not matched
    # again by the same pass - confirm.
    matched.input(0).replace_source_output(extra_relu.output(0))
    return True
def test_wrap_type_ctors():
    """WrapType accepts a list of type names, optionally with an input pattern."""
    data = opset8.parameter(PartialShape([1, 3, 22, 22]))
    relu = opset8.relu(data.output(0))
    slope = opset8.parameter(PartialShape([]))
    prelu = opset8.prelu(data.output(0), slope.output(0))

    activation_types = ["opset8.Relu", "opset8.PRelu"]

    any_activation = Matcher(WrapType(activation_types), "FindActivation")
    assert any_activation.match(relu)
    assert any_activation.match(prelu)

    input_pattern = WrapType("opset8.Parameter").output(0)
    activation_on_param = Matcher(WrapType(activation_types, input_pattern), "FindActivation")
    assert activation_on_param.match(relu)
def __init__(self):
    MatcherPass.__init__(self)
    # Set to True by the callback; used for testing purposes.
    self.model_changed = False

    exp_pattern = WrapType("opset8.Exp")

    def callback(m: Matcher) -> bool:
        # Remove the matched Exp by wiring its consumers straight to its input.
        exp_node = m.get_match_root()
        upstream = exp_node.input_value(0)
        exp_node.output(0).replace(upstream)
        self.model_changed = True
        return True

    remove_exp = Matcher(exp_pattern, "RemoveExp")
    self.register_matcher(remove_exp, callback)
def callback(m: Matcher) -> bool:
    """Insert an Exp right after the matched output.

    Input->...->Result  =>  Input->Exp->...->Result
    ``self`` is taken from the enclosing scope.
    """
    matched_output = m.get_match_value()
    # Collect consumers before Exp exists so Exp's own input is not rewired.
    consumers = matched_output.get_target_inputs()
    exp_node = opset8.exp(matched_output)
    exp_output = exp_node.output(0)
    for target in consumers:
        target.replace_source_output(exp_output)
    # For testing purposes.
    self.model_changed = True
    # Make the new Exp available for additional matching.
    self.register_new_node(exp_node)
    # False: the matched root itself was neither replaced nor changed.
    return False
def pattern_replacement():
    """Build a Parameter->Relu pattern and a callback that swaps Relu for Exp.

    Returns:
        A ``(Matcher, callback)`` pair; the matcher is named "SimpleReplacement".
    """
    param_pattern = WrapType("opset8.Parameter")
    relu_pattern = WrapType("opset8.Relu", param_pattern.output(0))

    def callback(m: Matcher) -> bool:
        matched_relu = m.get_match_root()
        # Capturing check: the pattern node must be linked to a matched graph value.
        assert relu_pattern in m.get_pattern_value_map()
        # input_value(0) is equivalent to input(0).get_source_output().
        replacement = opset8.exp(matched_relu.input_value(0))
        replace_node(matched_relu, replacement)
        return True

    matcher = Matcher(relu_pattern, "SimpleReplacement")
    return matcher, callback
def __init__(self):
    MatcherPass.__init__(self)
    # Set to True by the callback; used for testing purposes.
    self.model_changed = False

    relu_pattern = WrapType("opset8::Relu")

    def callback(m: Matcher) -> bool:
        # Input->Relu->Result  =>  Input->Relu->Relu->Result
        self.applied = True
        matched = m.get_match_root()
        upstream = matched.input(0).get_source_output()
        extra_relu = opset8.relu(upstream)
        self.model_changed = True
        # NOTE(review): extra_relu is intentionally not passed to
        # self.register_new_node; presumably so the inserted Relu is not
        # matched again by this pass - confirm.
        matched.input(0).replace_source_output(extra_relu.output(0))
        return True

    replacement_matcher = Matcher(relu_pattern, "PatternReplacement")
    self.register_matcher(replacement_matcher, callback)
def test_all_predicates():
    """Exercise each pattern predicate against Parameters with static,
    partially dynamic and fully dynamic shapes."""
    # ``np.long`` was removed in NumPy 1.24 (and re-added in NumPy 2.0 with a
    # different meaning, C ``long``), so a fixed-width integer dtype is used
    # for the non-float parameter instead.
    int_dtype = np.int64

    static_param = opset8.parameter(PartialShape([1, 3, 22, 22]), np.float32)
    dynamic_param = opset8.parameter(PartialShape([-1, 6]), int_dtype)
    fully_dynamic_param = opset8.parameter(PartialShape.dynamic())

    def matches(predicate, node):
        # Build a Parameter pattern constrained by ``predicate`` and try it on ``node``.
        return Matcher(WrapType("opset8.Parameter", predicate), "Test").match(node)

    f32 = get_element_type(np.float32)
    int_type = get_element_type(int_dtype)

    # Consumer count: nothing in this test consumes static_param.
    assert matches(consumers_count(0), static_param)
    assert not matches(consumers_count(1), static_param)

    # Individual static dimensions.
    assert matches(has_static_dim(1), static_param)
    assert not matches(has_static_dim(0), dynamic_param)
    assert matches(has_static_dims([0, 3]), static_param)
    assert not matches(has_static_dims([0, 1]), dynamic_param)

    # Whole-shape and rank staticness.
    assert matches(has_static_shape(), static_param)
    assert not matches(has_static_shape(), dynamic_param)
    assert matches(has_static_rank(), dynamic_param)
    assert not matches(has_static_rank(), fully_dynamic_param)

    # Element-type predicates.
    assert matches(type_matches(f32), static_param)
    assert not matches(type_matches(f32), dynamic_param)
    assert matches(type_matches_any([f32, int_type]), static_param)
    assert matches(type_matches_any([f32, int_type]), dynamic_param)