def create_h_sigmoid_act() -> GraphPattern:
    """Build the h-sigmoid activation pattern.

    The result holds two alternatives that differ only in the activation node
    placed between the add and the multiply:

        *INPUT_NODE* -> AddV2 -> ReLU  -> Mul
        *INPUT_NODE* -> AddV2 -> Relu6 -> Mul

    :return: GraphPattern with the ReLU alternative registered first and the
        Relu6 alternative second.
    """
    h_sigmoid = GraphPattern()

    # (label, type) of the activation node for each alternative, in registration order.
    activation_variants = (('RELU', 'ReLU'), ('RELU6', 'Relu6'))
    for act_label, act_type in activation_variants:
        alternative = GraphPattern()
        inp = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        add = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        mul = alternative.add_node(label='TF_OP_MUL', type='Mul')
        for src, dst in ((inp, add), (add, act), (act, mul)):
            alternative.add_edge(src, dst)
        h_sigmoid.add_pattern_alternative(alternative)

    return h_sigmoid
def test_join_patterns_func():
    """join_patterns with explicit edges links every 'first' node to every 'second' node."""
    # Expected: a single 'first' node feeding a single 'second' node.
    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    sink = expected.add_node(label='second', type=TestPattern.second_type)
    for src in [n for n in expected.graph.nodes if n != sink]:
        expected.add_edge(src, sink)

    # Full bipartite wiring between the two operand patterns.
    cross_edges = list(itertools.product(list(TestPattern.first_pattern.graph.nodes),
                                         list(TestPattern.second_pattern.graph.nodes)))

    # NOTE(review): shallow copy of a shared class-level fixture; if join_patterns
    # ever mutates the underlying graph object in place, this would leak into
    # TestPattern.first_pattern — consider copy.deepcopy.
    joined = copy.copy(TestPattern.first_pattern)
    joined.join_patterns(TestPattern.second_pattern, cross_edges)

    assert expected == joined
def test_ops_combination_two_patterns():
    """Check the '+' (connect) and '|' (unite) operators on two single-node patterns."""
    # '+' is expected to connect the operands: first -> second.
    summed = TestPattern.first_pattern + TestPattern.second_pattern
    expected_sum = GraphPattern()
    expected_sum.add_node(label='first', type=TestPattern.first_type)
    tail = expected_sum.add_node(label='second', type=TestPattern.second_type)
    for head in [n for n in expected_sum.graph.nodes if n != tail]:
        expected_sum.add_edge(head, tail)
    assert expected_sum == summed

    # '|' is expected to only unite the operands: two disconnected nodes.
    united = TestPattern.first_pattern | TestPattern.second_pattern
    expected_union = GraphPattern()
    expected_union.add_node(label='first', type=TestPattern.first_type)
    expected_union.add_node(label='second', type=TestPattern.second_type)
    assert expected_union == united
def test_join_pattern_with_special_input_node():
    """Joining through an ANY_PATTERN_NODE_TYPE node leaves no trace of that node.

    first is joined with a pattern holding a single ANY-type node, then with
    third; the expected result is simply first -> third.
    """
    # Work on a private copy: join_patterns modifies the pattern it is called on,
    # and TestPattern.first_pattern is a class-level fixture shared with the other
    # tests. Mutating it here would make those tests order-dependent.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second', type=GraphPattern.ANY_PATTERN_NODE_TYPE)
    pattern.join_patterns(second_pattern)
    pattern.join_patterns(TestPattern.third_pattern)

    # Expected: 'first' feeds 'third' directly.
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    assert pattern == ref_pattern
def test_join_patterns_func_three_patterns():
    """Chain three patterns with '+', then join one more 'third' wired from every node."""
    pattern = TestPattern.first_pattern + TestPattern.second_pattern + TestPattern.third_pattern
    # Snapshot the chain's nodes before joining so the extra node is fed from all of them.
    wiring = list(itertools.product(list(pattern.graph.nodes),
                                    list(TestPattern.third_pattern.graph.nodes)))
    pattern.join_patterns(TestPattern.third_pattern, wiring)

    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    second = expected.add_node(label='second', type=TestPattern.second_type)
    for src in [n for n in expected.graph.nodes if n != second]:
        expected.add_edge(src, second)

    # Append the chained 'third' after the current sink of first -> second.
    chain_sink = list(nx.topological_sort(expected.graph))[-1]
    chained_third = expected.add_node(label='third', type=TestPattern.third_type)
    expected.add_edge(chain_sink, chained_third)

    # The joined 'third' receives an edge from every node already present.
    joined_third = expected.add_node(label='third', type=TestPattern.third_type)
    for src in [n for n in expected.graph.nodes if n != joined_third]:
        expected.add_edge(src, joined_third)

    assert expected == pattern
def test_ops_combination_three_patterns():
    """Check '+', '|' and join_patterns combinations of three single-node patterns."""
    # '+' binds tighter than '|', so this is (first + second) | third:
    # first -> second, plus a disconnected 'third' node.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # '|' only unites: three disconnected nodes.
    pattern = TestPattern.first_pattern | TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # join_patterns with explicit edges: 'third' is fed by every node of first -> second.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern
    pattern_nodes = list(pattern.graph.nodes)
    third_nodes = list(TestPattern.third_pattern.graph.nodes)
    edges = list(itertools.product(pattern_nodes, third_nodes))
    pattern.join_patterns(TestPattern.third_pattern, edges)

    ref_pattern = GraphPattern()
    # This node carries first_type, so it is labelled 'first' (was mislabelled 'second').
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    added_node = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    assert ref_pattern == pattern
def create_h_swish_act() -> GraphPattern:
    """Build the h-swish activation pattern.

    Each alternative is the h-sigmoid subgraph (AddV2 -> activation -> Mul on the
    input) followed by a Multiply that consumes both the input and that Mul:

        *INPUT_NODE* -> AddV2 -> ReLU/Relu6 -> Mul -> Multiply
        *INPUT_NODE* ---------------------------------^

    :return: GraphPattern with the ReLU alternative registered first and the
        Relu6 alternative second.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    # connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    # in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    # string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    # after joining must return a list of output nodes so that these can be joined to later.
    # Currently cannot specify h_swish in terms of h_sigmoid due to this.
    h_swish = GraphPattern()

    # (label, type) of the activation node for each alternative, in registration order.
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        alternative = GraphPattern()
        inp = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        add = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        scale = alternative.add_node(label='TF_OP_MUL', type='Mul')
        for src, dst in ((inp, add), (add, act), (act, scale)):
            alternative.add_edge(src, dst)

        # Closing multiply: the original input times the h-sigmoid branch output.
        product = alternative.add_node(label='MULTIPLY', type='Multiply')
        alternative.add_edge(inp, product)
        alternative.add_edge(scale, product)

        h_swish.add_pattern_alternative(alternative)

    return h_swish