Esempio n. 1
0
def test_join_patterns_func():
    """Joining two single-node patterns with an explicit edge list.

    Every node of ``first_pattern`` is connected to every node of
    ``second_pattern``; the result must equal a reference graph built by hand
    (a 'first' node with an edge into a 'second' node).
    """
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    first_nodes = list(TestPattern.first_pattern.graph.nodes)
    second_nodes = list(TestPattern.second_pattern.graph.nodes)
    edges = list(itertools.product(first_nodes, second_nodes))
    # join_patterns mutates its receiver in place. A shallow copy.copy() would
    # share the underlying graph object with the class-level fixture, so the
    # join could leak into other tests; deepcopy gives this test its own graph.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    pattern.join_patterns(TestPattern.second_pattern, edges)
    assert ref_pattern == pattern
Esempio n. 2
0
def create_h_swish_act() -> GraphPattern:
    """Build the h-swish activation pattern with a ReLU and a ReLU6 alternative.

    Each alternative has the shape::

        *INPUT_NODE* -> ADD (AddV2) -> <activation> -> TF_OP_MUL (Mul) -> MULTIPLY (Multiply)
              \\_______________________________________________________________/^

    i.e. the input also feeds MULTIPLY directly. The two alternatives differ
    only in the activation node: ``RELU``/``ReLU`` or ``RELU6``/``Relu6``.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    #  connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    #  in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    #  string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    #  after joining must return a list of output nodes so that these can be joined to later.
    #  Currently cannot specify h_swish in terms of h_sigmoid due to this.

    main_pattern = GraphPattern()

    # (label, type) of the activation node for each alternative, in registration order.
    activation_variants = (('RELU', 'ReLU'), ('RELU6', 'Relu6'))
    for act_label, act_type in activation_variants:
        alternative = GraphPattern()
        input_node = alternative.add_node(
            label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)

        add = alternative.add_node(label='ADD', type='AddV2')
        activation = alternative.add_node(label=act_label, type=act_type)
        mul = alternative.add_node(label='TF_OP_MUL', type='Mul')

        alternative.add_edge(input_node, add)
        alternative.add_edge(add, activation)
        alternative.add_edge(activation, mul)

        # The final product multiplies the original input by the gated branch.
        multiply = alternative.add_node(label='MULTIPLY', type='Multiply')
        alternative.add_edge(input_node, multiply)
        alternative.add_edge(mul, multiply)

        main_pattern.add_pattern_alternative(alternative)

    return main_pattern
Esempio n. 3
0
def create_h_sigmoid_act() -> GraphPattern:
    """Build the h-sigmoid activation pattern with a ReLU and a ReLU6 alternative.

    Each alternative is the chain::

        *INPUT_NODE* -> ADD (AddV2) -> <activation> -> TF_OP_MUL (Mul)

    where the activation node is ``RELU``/``ReLU`` in the first alternative and
    ``RELU6``/``Relu6`` in the second.
    """
    main_pattern = GraphPattern()

    # (label, type) of the activation node for each alternative, in registration order.
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        alternative = GraphPattern()

        input_node = alternative.add_node(
            label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        add = alternative.add_node(label='ADD', type='AddV2')
        activation = alternative.add_node(label=act_label, type=act_type)
        mul = alternative.add_node(label='TF_OP_MUL', type='Mul')

        alternative.add_edge(input_node, add)
        alternative.add_edge(add, activation)
        alternative.add_edge(activation, mul)

        main_pattern.add_pattern_alternative(alternative)

    return main_pattern
Esempio n. 4
0
class TestPattern:
    """Shared GraphPattern fixtures for the pattern-combination tests.

    Defines five single-node patterns (``first_pattern`` .. ``fifth_pattern``)
    and two six-node graphs that differ only in the type of their three
    auxiliary input nodes (NON_PATTERN vs ANY_PATTERN), see the diagram below.
    """
    first_type = ['a', 'b']
    second_type = ['c', 'd']
    third_type = ['e']
    forth_type = [GraphPattern.NON_PATTERN_NODE_TYPE]
    fifth_type = [GraphPattern.ANY_PATTERN_NODE_TYPE]

    first_pattern = GraphPattern()
    first_pattern.add_node(label='first', type=first_type)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second', type=second_type)
    third_pattern = GraphPattern()
    third_pattern.add_node(label='third', type=third_type)
    forth_pattern = GraphPattern()
    forth_pattern.add_node(label='forth', type=forth_type)
    fifth_pattern = GraphPattern()
    # Bug fix: previously passed `type=fifth_pattern` (the pattern object
    # itself) instead of the ANY_PATTERN type list defined above.
    fifth_pattern.add_node(label='fifth', type=fifth_type)

    # pattern_with_non_pattern_nodes |  pattern_with_any_pattern_nodes
    #        NON                     |            ANY
    #         |                      |             |
    #         1                      |             1
    #         |                      |             |
    #         2  NON                 |             2  ANY
    #        / \ /                   |            / \ /
    #       4   3                    |           4   3
    #       |  /                     |           |  /
    #       | /                      |           | /
    #       |/                       |           |/
    #       5                        |           5
    #       |                        |           |
    #       6---NON                  |           6---ANY

    pattern_with_non_pattern_nodes = GraphPattern()
    pattern_with_any_pattern_nodes = GraphPattern()
    # Nodes 1-6 are identical in both graphs.
    common_nodes = {
        '1': {
            'type': 'a'
        },
        '2': {
            'type': 'b'
        },
        '3': {
            'type': 'c'
        },
        '4': {
            'type': 'a'
        },
        '5': {
            'type': 'e'
        },
        '6': {
            'type': 'a'
        }
    }
    # Nodes 7-9 are the auxiliary inputs (NON / ANY in the diagram).
    non_pattern_nodes = {
        '7': {
            'type': GraphPattern.NON_PATTERN_NODE_TYPE
        },
        '8': {
            'type': GraphPattern.NON_PATTERN_NODE_TYPE
        },
        '9': {
            'type': GraphPattern.NON_PATTERN_NODE_TYPE
        }
    }
    any_pattern_nodes = {
        '7': {
            'type': GraphPattern.ANY_PATTERN_NODE_TYPE
        },
        '8': {
            'type': GraphPattern.ANY_PATTERN_NODE_TYPE
        },
        '9': {
            'type': GraphPattern.ANY_PATTERN_NODE_TYPE
        }
    }
    # Map each string label to the node handle returned by add_node so the
    # edge list below can be expressed in terms of labels.
    label_to_non_pattern_nodes = {}
    label_to_any_pattern_nodes = {}
    for label, attrs in common_nodes.items():
        label_to_non_pattern_nodes[
            label] = pattern_with_non_pattern_nodes.add_node(label=label,
                                                             **attrs)
        label_to_any_pattern_nodes[
            label] = pattern_with_any_pattern_nodes.add_node(label=label,
                                                             **attrs)
    for label, attrs in non_pattern_nodes.items():
        label_to_non_pattern_nodes[
            label] = pattern_with_non_pattern_nodes.add_node(label=label,
                                                             **attrs)
    for label, attrs in any_pattern_nodes.items():
        label_to_any_pattern_nodes[
            label] = pattern_with_any_pattern_nodes.add_node(label=label,
                                                             **attrs)

    edges = [('1', '2'), ('2', '3'), ('2', '4'), ('4', '5'), ('5', '6'),
             ('3', '5'), ('7', '1'), ('8', '3'), ('9', '6')]
    for edge in edges:
        pattern_with_non_pattern_nodes.add_edge(
            label_to_non_pattern_nodes[edge[0]],
            label_to_non_pattern_nodes[edge[1]])
        pattern_with_any_pattern_nodes.add_edge(
            label_to_any_pattern_nodes[edge[0]],
            label_to_any_pattern_nodes[edge[1]])
Esempio n. 5
0
def test_join_pattern_with_special_input_node():
    """Joining a single ANY_PATTERN_NODE_TYPE node leaves no trace in the result.

    After joining that pattern and then ``third_pattern`` onto a 'first' node,
    the reference below expects only 'first' -> 'third'.
    """
    # Build a private copy of TestPattern.first_pattern (same construction)
    # rather than aliasing it: join_patterns mutates its receiver in place,
    # and aliasing the class-level fixture would leak the joined nodes into
    # every other test that uses TestPattern.first_pattern.
    pattern = GraphPattern()
    pattern.add_node(label='first', type=TestPattern.first_type)

    second_pattern = GraphPattern()
    second_pattern.add_node(label='second',
                            type=GraphPattern.ANY_PATTERN_NODE_TYPE)
    pattern.join_patterns(second_pattern)
    pattern.join_patterns(TestPattern.third_pattern)

    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    assert pattern == ref_pattern
Esempio n. 6
0
def test_join_patterns_func_three_patterns():
    """Join an extra ``third`` pattern onto ``first + second + third``.

    Every node of the chained pattern gets an edge into the newly joined
    node; the result is compared with a reference graph built node by node.
    """
    chained = (TestPattern.first_pattern + TestPattern.second_pattern +
               TestPattern.third_pattern)
    sources = list(chained.graph.nodes)
    targets = list(TestPattern.third_pattern.graph.nodes)
    chained.join_patterns(TestPattern.third_pattern,
                          list(itertools.product(sources, targets)))

    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    second_node = expected.add_node(label='second',
                                    type=TestPattern.second_type)
    for predecessor in [n for n in expected.graph.nodes if n != second_node]:
        expected.add_edge(predecessor, second_node)

    # The chained `third` hangs off the last node (in topological order) of
    # the first -> second sub-graph.
    tail = list(nx.topological_sort(expected.graph))[-1]
    chained_third = expected.add_node(label='third',
                                      type=TestPattern.third_type)
    expected.add_edge(tail, chained_third)

    # The joined `third` receives an edge from every node added so far.
    joined_third = expected.add_node(label='third',
                                     type=TestPattern.third_type)
    for predecessor in [n for n in expected.graph.nodes if n != joined_third]:
        expected.add_edge(predecessor, joined_third)

    assert expected == chained
Esempio n. 7
0
def test_ops_combination_three_patterns():
    """Check ``+`` / ``|`` / ``join_patterns`` combinations of three patterns.

    Going by the reference graphs below, ``+`` adds an edge from every node of
    the left operand into the right operand's node, while ``|`` adds the right
    operand as a disconnected node.
    """
    # `+` binds tighter than `|`, so this is (first + second) | third.
    pattern = (TestPattern.first_pattern + TestPattern.second_pattern) | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    pattern = TestPattern.first_pattern | TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    pattern = TestPattern.first_pattern + TestPattern.second_pattern
    pattern_nodes = list(pattern.graph.nodes)
    third_nodes = list(TestPattern.third_pattern.graph.nodes)
    edges = list(itertools.product(pattern_nodes, third_nodes))
    pattern.join_patterns(TestPattern.third_pattern, edges)
    ref_pattern = GraphPattern()
    # Bug fix: this node mirrors first_pattern, so its label is 'first'
    # (it was mislabelled 'second').
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    assert ref_pattern == pattern
Esempio n. 8
0
def test_ops_combination_two_patterns():
    """``+`` chains two patterns with an edge; ``|`` keeps them disconnected."""
    # Sequential combination: first -> second.
    combined = TestPattern.first_pattern + TestPattern.second_pattern
    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    second_node = expected.add_node(label='second',
                                    type=TestPattern.second_type)
    for predecessor in [n for n in expected.graph.nodes if n != second_node]:
        expected.add_edge(predecessor, second_node)
    assert expected == combined

    # Alternative combination: two nodes and no edges.
    combined = TestPattern.first_pattern | TestPattern.second_pattern
    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    expected.add_node(label='second', type=TestPattern.second_type)
    assert expected == combined
Esempio n. 9
0
def _get_tf_hw_fused_patterns() -> HWFusedPatterns:
    """Build the registry of TF op sequences treated as one fused HW operation.

    Registers, in order: H_SIGMOID, H_SWISH, LINEAR (optionally followed by
    Identity), BN_ACT_OR_ACT_BN, LINEAR + ANY_AG_BN_ACT_COMBO and
    ELTWISE + ANY_AG_BN_ACT_COMBO, all with ``match=True``.
    """
    hw_patterns = HWFusedPatterns()

    linear_ops = GraphPattern()
    linear_ops.add_node(**LINEAR_OPERATIONS)

    eltwise_ops = GraphPattern()
    eltwise_ops.add_node(**ELEMENTWISE_OPERATIONS)

    batch_norm = GraphPattern()
    batch_norm.add_node(**BATCH_NORMALIZATION_OPERATIONS)

    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    hw_patterns.register(h_sigmoid, 'H_SIGMOID', match=True)
    hw_patterns.register(h_swish, 'H_SWISH', match=True)

    atomic_activations = GraphPattern()
    atomic_activations.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS)
    activations = atomic_activations | h_swish | h_sigmoid
    # Parentheses only make the existing precedence explicit (`+` binds tighter than `|`).
    bn_act_permutation = (batch_norm + activations) | (activations + batch_norm)
    any_bn_act_combo = batch_norm | activations | bn_act_permutation

    identity = GraphPattern()
    identity.add_node(type=['Identity'], label='IDENTITY')
    linear_maybe_identity = linear_ops | (linear_ops + identity)

    agnostic_ops = GraphPattern()
    agnostic_ops.add_node(**QUANTIZATION_AGNOSTIC_OPERATIONS)
    any_ag_bn_act_combo = (agnostic_ops + activations) | any_bn_act_combo

    hw_patterns.register(linear_maybe_identity, name='LINEAR', match=True)
    hw_patterns.register(bn_act_permutation, name='BN_ACT_OR_ACT_BN', match=True)
    hw_patterns.register(linear_maybe_identity + any_ag_bn_act_combo,
                         'LINEAR + ANY_AG_BN_ACT_COMBO',
                         match=True)
    hw_patterns.register(eltwise_ops + any_ag_bn_act_combo,
                         'ELTWISE + ANY_AG_BN_ACT_COMBO',
                         match=True)
    return hw_patterns
Esempio n. 10
0
def _get_torch_hw_fused_patterns() -> HWFusedPatterns:
    """Build the registry of PyTorch op sequences treated as one fused HW operation.

    Single-op groups (linear, matmul, batch norm, activations, arithmetic) are
    registered with ``match=False``; the fused combinations (LINEAR + BN_ACT_PERM,
    MATMUL + ARITHMETIC, BN/ACT permutations, ARITHMETIC + BN_ACT_PERM,
    GROUP_NORM + RELU, L2_NORM) are registered with ``match=True``.
    """
    retval = HWFusedPatterns()
    linear_ops = GraphPattern()
    linear_ops.add_node(**LINEAR_OPERATIONS)
    retval.register(linear_ops, LINEAR_OPERATIONS['label'], match=False)

    matmul_ops = GraphPattern()
    matmul_ops.add_node(**MATMUL_OPERATIONS)
    # Bug fix: this previously registered `linear_ops` under the MATMUL label.
    retval.register(matmul_ops, MATMUL_OPERATIONS['label'], match=False)

    batch_norm = GraphPattern()
    batch_norm.add_node(**BATCH_NORMALIZATION_OPERATIONS)
    retval.register(batch_norm,
                    BATCH_NORMALIZATION_OPERATIONS['label'],
                    match=False)

    atomic_activations = GraphPattern()
    atomic_activations.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS)
    swish = create_swish_act()
    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    activations = atomic_activations | swish | h_swish | h_sigmoid
    retval.register(activations, 'ACTIVATIONS', match=False)

    arithmetic_ops = GraphPattern()
    arithmetic_ops.add_node(**ARITHMETIC_OPERATIONS)
    retval.register(arithmetic_ops,
                    ARITHMETIC_OPERATIONS['label'],
                    match=False)

    # BN -> act, act -> BN, or either one on its own. Parentheses only make
    # the existing precedence explicit (`+` binds tighter than `|`).
    batch_norm_activations_permutation = ((batch_norm + activations)
                                          | (activations + batch_norm)
                                          | batch_norm
                                          | activations)

    retval.register(linear_ops + batch_norm_activations_permutation,
                    'LINEAR + BN_ACT_PERM',
                    match=True)
    retval.register(matmul_ops + arithmetic_ops,
                    'MATMUL + ARITHMETIC',
                    match=True)
    retval.register(batch_norm + activations, 'BN + ACTIVATIONS', match=True)
    retval.register(activations + batch_norm, 'ACTIVATIONS + BN', match=True)
    retval.register(arithmetic_ops + batch_norm_activations_permutation,
                    'ARITHMETIC + BN_ACT_PERM',
                    match=True)

    group_norm = GraphPattern()
    group_norm.add_node(**GROUP_NORMALIZATION_OPERATIONS)
    relu = GraphPattern()
    relu.add_node(**RELU_OPERATIONS)
    retval.register(group_norm + relu, 'GROUP_NORM + RELU', match=True)

    l2_norm = create_l2_norm()
    retval.register(l2_norm, 'L2_NORM', match=True)
    return retval