def create_h_sigmoid_act() -> GraphPattern:
    """Build the h_sigmoid activation pattern.

    The result holds two pattern alternatives, one per activation variant:

        *INPUT_NODE* -> ADD (AddV2) -> RELU  (ReLU)  -> TF_OP_MUL (Mul)
        *INPUT_NODE* -> ADD (AddV2) -> RELU6 (Relu6) -> TF_OP_MUL (Mul)

    The leading node of each alternative has type
    GraphPattern.NON_PATTERN_NODE_TYPE.
    """
    h_sigmoid = GraphPattern()
    # (label, type) of the activation node for each alternative, in the
    # order the alternatives are added.
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        alternative = GraphPattern()
        source = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        shift = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        scale = alternative.add_node(label='TF_OP_MUL', type='Mul')
        for src, dst in ((source, shift), (shift, act), (act, scale)):
            alternative.add_edge(src, dst)
        h_sigmoid.add_pattern_alternative(alternative)
    return h_sigmoid
def test_join_patterns_func_three_patterns():
    """Join another 'third' pattern onto (first + second + third) via explicit edges.

    The joined 'third' node gets an edge from every node already present, so
    the expected graph is the chain first -> second -> third plus a second
    'third' node fed by all three of them.
    """
    pattern = (TestPattern.first_pattern
               + TestPattern.second_pattern
               + TestPattern.third_pattern)
    # Snapshot node ids before joining: every existing node is wired to every
    # node of the pattern being joined.
    existing_nodes = list(pattern.graph.nodes)
    joined_nodes = list(TestPattern.third_pattern.graph.nodes)
    pattern.join_patterns(TestPattern.third_pattern,
                          list(itertools.product(existing_nodes, joined_nodes)))

    expected = GraphPattern()
    first = expected.add_node(label='first', type=TestPattern.first_type)
    second = expected.add_node(label='second', type=TestPattern.second_type)
    expected.add_edge(first, second)
    chained_third = expected.add_node(label='third', type=TestPattern.third_type)
    expected.add_edge(second, chained_third)
    fan_in_third = expected.add_node(label='third', type=TestPattern.third_type)
    for source in (first, second, chained_third):
        expected.add_edge(source, fan_in_third)

    assert expected == pattern
def test_ops_combination_two_patterns():
    """`+` chains two single-node patterns; `|` keeps them unconnected."""
    chained = TestPattern.first_pattern + TestPattern.second_pattern
    chained_ref = GraphPattern()
    head = chained_ref.add_node(label='first', type=TestPattern.first_type)
    tail = chained_ref.add_node(label='second', type=TestPattern.second_type)
    chained_ref.add_edge(head, tail)
    assert chained_ref == chained

    alternatives = TestPattern.first_pattern | TestPattern.second_pattern
    alternatives_ref = GraphPattern()
    alternatives_ref.add_node(label='first', type=TestPattern.first_type)
    alternatives_ref.add_node(label='second', type=TestPattern.second_type)
    assert alternatives_ref == alternatives
def create_h_swish_act() -> GraphPattern:
    """Build the h_swish activation pattern.

    Each of the two alternatives is the h_sigmoid chain
    (*INPUT_NODE* -> ADD -> RELU/RELU6 -> TF_OP_MUL). A final MULTIPLY
    ('Multiply') node combines the TF_OP_MUL output with the same input node.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    # connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    # in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    # string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    # after joining must return a list of output nodes so that these can be joined to later.
    # Currently cannot specify h_swish in terms of h_sigmoid due to this.
    h_swish = GraphPattern()
    # (label, type) of the activation node for the ReLU and ReLU6 alternatives.
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        alternative = GraphPattern()
        source = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        shift = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        gate = alternative.add_node(label='TF_OP_MUL', type='Mul')
        alternative.add_edge(source, shift)
        alternative.add_edge(shift, act)
        alternative.add_edge(act, gate)
        # The input also feeds the final multiplication, alongside the gate.
        product = alternative.add_node(label='MULTIPLY', type='Multiply')
        alternative.add_edge(source, product)
        alternative.add_edge(gate, product)
        h_swish.add_pattern_alternative(alternative)
    return h_swish
def test_join_pattern_with_special_input_node():
    """Join through a pattern whose only node is ANY_PATTERN_NODE_TYPE.

    Joining first_pattern with a single ANY_PATTERN_NODE_TYPE node and then
    with third_pattern is expected to equal the plain chain first -> third.
    """
    # join_patterns modifies the pattern it is called on. Work on a deep copy
    # so the class-level TestPattern.first_pattern, shared with other tests,
    # is not mutated.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second', type=GraphPattern.ANY_PATTERN_NODE_TYPE)
    pattern.join_patterns(second_pattern)
    pattern.join_patterns(TestPattern.third_pattern)

    ref_pattern = GraphPattern()
    first_node = ref_pattern.add_node(label='first', type=TestPattern.first_type)
    third_node = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    ref_pattern.add_edge(first_node, third_node)
    assert pattern == ref_pattern
def _get_torch_hw_fused_patterns() -> HWFusedPatterns:
    """Build the HWFusedPatterns registry for the torch backend.

    The single-operation groups (linear, matmul, batch norm, activations,
    arithmetic) are registered with match=False. The fused combinations built
    from them, plus GROUP_NORM + RELU and L2_NORM, are registered with
    match=True.
    """
    retval = HWFusedPatterns()

    linear_ops = GraphPattern()
    linear_ops.add_node(**LINEAR_OPERATIONS)
    retval.register(linear_ops, LINEAR_OPERATIONS['label'], match=False)

    matmul_ops = GraphPattern()
    matmul_ops.add_node(**MATMUL_OPERATIONS)
    # Register the matmul pattern itself under the MATMUL label. Passing
    # linear_ops here would file the linear pattern under the wrong name.
    retval.register(matmul_ops, MATMUL_OPERATIONS['label'], match=False)

    batch_norm = GraphPattern()
    batch_norm.add_node(**BATCH_NORMALIZATION_OPERATIONS)
    retval.register(batch_norm, BATCH_NORMALIZATION_OPERATIONS['label'], match=False)

    atomic_activations = GraphPattern()
    atomic_activations.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS)
    swish = create_swish_act()
    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    activations = atomic_activations | swish | h_swish | h_sigmoid
    retval.register(activations, 'ACTIVATIONS', match=False)

    arithmetic_ops = GraphPattern()
    arithmetic_ops.add_node(**ARITHMETIC_OPERATIONS)
    retval.register(arithmetic_ops, ARITHMETIC_OPERATIONS['label'], match=False)

    # `+` binds tighter than `|`, so this reads as:
    # (BN -> ACT) | (ACT -> BN) | BN | ACT
    batch_norm_activations_permutation = (batch_norm + activations
                                          | activations + batch_norm
                                          | batch_norm
                                          | activations)

    retval.register(linear_ops + batch_norm_activations_permutation, 'LINEAR + BN_ACT_PERM', match=True)
    retval.register(matmul_ops + arithmetic_ops, 'MATMUL + ARITHMETIC', match=True)
    retval.register(batch_norm + activations, 'BN + ACTIVATIONS', match=True)
    retval.register(activations + batch_norm, 'ACTIVATIONS + BN', match=True)
    retval.register(arithmetic_ops + batch_norm_activations_permutation, 'ARITHMETIC + BN_ACT_PERM', match=True)

    group_norm = GraphPattern()
    group_norm.add_node(**GROUP_NORMALIZATION_OPERATIONS)
    relu = GraphPattern()
    relu.add_node(**RELU_OPERATIONS)
    retval.register(group_norm + relu, 'GROUP_NORM + RELU', match=True)

    l2_norm = create_l2_norm()
    retval.register(l2_norm, 'L2_NORM', match=True)
    return retval
def test_ops_combination_three_patterns():
    """Combine three single-node patterns with `+`, `|` and join_patterns."""
    # `+` binds tighter than `|`: (first -> second) alongside an unconnected 'third'.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    first_node = ref_pattern.add_node(label='first', type=TestPattern.first_type)
    second_node = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_edge(first_node, second_node)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # Three unconnected alternatives.
    pattern = TestPattern.first_pattern | TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # (first + second) joined with 'third' through explicit edges from every node.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern
    pattern_nodes = list(pattern.graph.nodes)
    third_nodes = list(TestPattern.third_pattern.graph.nodes)
    edges = list(itertools.product(pattern_nodes, third_nodes))
    pattern.join_patterns(TestPattern.third_pattern, edges)

    ref_pattern = GraphPattern()
    # The first-type node is labelled 'first' (it was mislabelled 'second').
    first_node = ref_pattern.add_node(label='first', type=TestPattern.first_type)
    second_node = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_edge(first_node, second_node)
    third_node = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    ref_pattern.add_edge(first_node, third_node)
    ref_pattern.add_edge(second_node, third_node)
    assert ref_pattern == pattern
def test_join_patterns_func():
    """join_patterns with explicit edges links every first node to every second node."""
    ref_pattern = GraphPattern()
    first_node = ref_pattern.add_node(label='first', type=TestPattern.first_type)
    second_node = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_edge(first_node, second_node)

    first_nodes = list(TestPattern.first_pattern.graph.nodes)
    second_nodes = list(TestPattern.second_pattern.graph.nodes)
    edges = list(itertools.product(first_nodes, second_nodes))
    # copy.copy is shallow: unless GraphPattern defines its own __copy__, the
    # copy shares its graph object with the class-level fixture. Any in-place
    # change made by join_patterns could then leak into
    # TestPattern.first_pattern. A deep copy keeps the fixture isolated.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    pattern.join_patterns(TestPattern.second_pattern, edges)
    assert ref_pattern == pattern
def _get_tf_hw_fused_patterns() -> HWFusedPatterns:
    """Build the HWFusedPatterns registry for the TF backend.

    All entries are registered with match=True, in this order:
    H_SIGMOID, H_SWISH, LINEAR (optionally followed by Identity),
    BN_ACT_OR_ACT_BN, and LINEAR / ELTWISE followed by an activation,
    batch norm or quantization-agnostic-op combination.
    """
    def single_node_pattern(**node_attrs) -> GraphPattern:
        # A GraphPattern holding exactly one node built from node_attrs.
        one_node = GraphPattern()
        one_node.add_node(**node_attrs)
        return one_node

    fused = HWFusedPatterns()

    linear = single_node_pattern(**LINEAR_OPERATIONS)
    eltwise = single_node_pattern(**ELEMENTWISE_OPERATIONS)
    bn = single_node_pattern(**BATCH_NORMALIZATION_OPERATIONS)

    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    fused.register(h_sigmoid, 'H_SIGMOID', match=True)
    fused.register(h_swish, 'H_SWISH', match=True)

    act = single_node_pattern(**ATOMIC_ACTIVATIONS_OPERATIONS) | h_swish | h_sigmoid
    # `+` binds tighter than `|`: (BN -> ACT) | (ACT -> BN).
    bn_act_perm = bn + act | act + bn
    any_bn_act = bn | act | bn_act_perm

    linear_with_opt_identity = linear | (linear + single_node_pattern(type=['Identity'], label='IDENTITY'))

    any_ag_bn_act = single_node_pattern(**QUANTIZATION_AGNOSTIC_OPERATIONS) + act | any_bn_act

    fused.register(linear_with_opt_identity, name='LINEAR', match=True)
    fused.register(bn_act_perm, name='BN_ACT_OR_ACT_BN', match=True)
    fused.register(linear_with_opt_identity + any_ag_bn_act, 'LINEAR + ANY_AG_BN_ACT_COMBO', match=True)
    fused.register(eltwise + any_ag_bn_act, 'ELTWISE + ANY_AG_BN_ACT_COMBO', match=True)
    return fused
class TestPattern:
    """Shared GraphPattern fixtures for the pattern tests."""
    first_type = ['a', 'b']
    second_type = ['c', 'd']
    third_type = ['e']
    forth_type = [GraphPattern.NON_PATTERN_NODE_TYPE]
    fifth_type = [GraphPattern.ANY_PATTERN_NODE_TYPE]

    first_pattern = GraphPattern()
    first_pattern.add_node(label='first', type=first_type)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second', type=second_type)
    third_pattern = GraphPattern()
    third_pattern.add_node(label='third', type=third_type)
    forth_pattern = GraphPattern()
    forth_pattern.add_node(label='forth', type=forth_type)
    fifth_pattern = GraphPattern()
    # The node type must be the type list fifth_type, not the pattern object
    # currently being built.
    fifth_pattern.add_node(label='fifth', type=fifth_type)

    # pattern_with_non_pattern_nodes | pattern_with_any_pattern_nodes
    #         NON                     |         ANY
    #          |                      |          |
    #          1                      |          1
    #          |                      |          |
    #          2  NON                 |          2  ANY
    #         / \ /                   |         / \ /
    #        4   3                    |        4   3
    #        |  /                     |        |  /
    #        | /                      |        | /
    #        |/                       |        |/
    #        5                        |        5
    #        |                        |        |
    #        6---NON                  |        6---ANY
    #
    # Edges point downwards; nodes 7, 8 and 9 (NON or ANY) feed into
    # 1, 3 and 6 respectively.
    pattern_with_non_pattern_nodes = GraphPattern()
    pattern_with_any_pattern_nodes = GraphPattern()

    # Nodes shared by both patterns.
    common_nodes = {
        '1': {'type': 'a'},
        '2': {'type': 'b'},
        '3': {'type': 'c'},
        '4': {'type': 'a'},
        '5': {'type': 'e'},
        '6': {'type': 'a'},
    }
    # Special input nodes, typed differently in each pattern.
    non_pattern_nodes = {
        '7': {'type': GraphPattern.NON_PATTERN_NODE_TYPE},
        '8': {'type': GraphPattern.NON_PATTERN_NODE_TYPE},
        '9': {'type': GraphPattern.NON_PATTERN_NODE_TYPE},
    }
    any_pattern_nodes = {
        '7': {'type': GraphPattern.ANY_PATTERN_NODE_TYPE},
        '8': {'type': GraphPattern.ANY_PATTERN_NODE_TYPE},
        '9': {'type': GraphPattern.ANY_PATTERN_NODE_TYPE},
    }

    # Map each label to the node id returned by add_node, per pattern.
    label_to_non_pattern_nodes = {}
    label_to_any_pattern_nodes = {}
    for label, attrs in common_nodes.items():
        label_to_non_pattern_nodes[label] = pattern_with_non_pattern_nodes.add_node(label=label, **attrs)
        label_to_any_pattern_nodes[label] = pattern_with_any_pattern_nodes.add_node(label=label, **attrs)
    for label, attrs in non_pattern_nodes.items():
        label_to_non_pattern_nodes[label] = pattern_with_non_pattern_nodes.add_node(label=label, **attrs)
    for label, attrs in any_pattern_nodes.items():
        label_to_any_pattern_nodes[label] = pattern_with_any_pattern_nodes.add_node(label=label, **attrs)

    edges = [('1', '2'), ('2', '3'), ('2', '4'), ('4', '5'), ('5', '6'), ('3', '5'), ('7', '1'), ('8', '3'),
             ('9', '6')]
    for edge in edges:
        pattern_with_non_pattern_nodes.add_edge(label_to_non_pattern_nodes[edge[0]],
                                                label_to_non_pattern_nodes[edge[1]])
        pattern_with_any_pattern_nodes.add_edge(label_to_any_pattern_nodes[edge[0]],
                                                label_to_any_pattern_nodes[edge[1]])