예제 #1
0
def test_join_patterns_func_three_patterns():
    """Join a fresh third pattern onto ``first + second + third``.

    Every node already present in the combined pattern gets an edge to the
    newly joined third node. The result is compared against a hand-built
    reference graph.
    """
    pattern = (TestPattern.first_pattern + TestPattern.second_pattern +
               TestPattern.third_pattern)
    existing_nodes = list(pattern.graph.nodes)
    joined_nodes = list(TestPattern.third_pattern.graph.nodes)
    pattern.join_patterns(TestPattern.third_pattern,
                          list(itertools.product(existing_nodes, joined_nodes)))

    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    second = expected.add_node(label='second', type=TestPattern.second_type)
    for predecessor in [n for n in expected.graph.nodes if n != second]:
        expected.add_edge(predecessor, second)

    # The chained third node hangs off the current sink of the graph.
    tail = list(nx.topological_sort(expected.graph))[-1]
    chained_third = expected.add_node(label='third', type=TestPattern.third_type)
    expected.add_edge(tail, chained_third)

    # The explicitly joined third node receives an edge from every other node.
    joined_third = expected.add_node(label='third', type=TestPattern.third_type)
    for predecessor in [n for n in expected.graph.nodes if n != joined_third]:
        expected.add_edge(predecessor, joined_third)

    assert expected == pattern
예제 #2
0
def create_h_sigmoid_act() -> GraphPattern:
    """Build the h_sigmoid activation pattern.

    The returned pattern has two alternatives. They differ only in the
    activation node of the chain ``input -> AddV2 -> <activation> -> Mul``:
    the first uses ``ReLU`` and the second uses ``Relu6``.
    """
    h_sigmoid = GraphPattern()
    # (label, type) of the activation node, in the order the alternatives are added.
    activation_variants = (('RELU', 'ReLU'), ('RELU6', 'Relu6'))
    for act_label, act_type in activation_variants:
        alternative = GraphPattern()
        chain = [
            alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE),
            alternative.add_node(label='ADD', type='AddV2'),
            alternative.add_node(label=act_label, type=act_type),
            alternative.add_node(label='TF_OP_MUL', type='Mul'),
        ]
        # Connect consecutive nodes: input -> add -> activation -> mul.
        for src, dst in zip(chain, chain[1:]):
            alternative.add_edge(src, dst)
        h_sigmoid.add_pattern_alternative(alternative)
    return h_sigmoid
예제 #3
0
def test_join_patterns_func():
    """Join ``second_pattern`` onto a copy of ``first_pattern`` with explicit edges.

    Every first-pattern node gets an edge to every second-pattern node. The
    result must equal a hand-built ``first -> second`` reference.
    """
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    first_nodes = list(TestPattern.first_pattern.graph.nodes)
    second_nodes = list(TestPattern.second_pattern.graph.nodes)
    edges = list(itertools.product(first_nodes, second_nodes))
    # Deep copy: a shallow copy would share ``graph`` with the class-level
    # fixture, so an in-place join would leak into other tests.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    pattern.join_patterns(TestPattern.second_pattern, edges)
    assert ref_pattern == pattern
예제 #4
0
def create_h_swish_act() -> GraphPattern:
    """Build the h_swish activation pattern.

    Each alternative is the h_sigmoid-shaped chain
    ``input -> AddV2 -> <activation> -> Mul`` followed by a ``Multiply`` node.
    The ``Multiply`` node takes both the original input and the ``Mul`` output.
    Two alternatives are added, with ``ReLU`` and with ``Relu6`` as the activation.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    #  connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    #  in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    #  string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    #  after joining must return a list of output nodes so that these can be joined to later.
    #  Currently cannot specify h_swish in terms of h_sigmoid due to this.
    h_swish = GraphPattern()
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        alternative = GraphPattern()
        input_node = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        sigmoid_chain = [
            input_node,
            alternative.add_node(label='ADD', type='AddV2'),
            alternative.add_node(label=act_label, type=act_type),
            alternative.add_node(label='TF_OP_MUL', type='Mul'),
        ]
        for src, dst in zip(sigmoid_chain, sigmoid_chain[1:]):
            alternative.add_edge(src, dst)

        # The final Multiply consumes both the pattern input and the Mul output.
        multiply_node = alternative.add_node(label='MULTIPLY', type='Multiply')
        alternative.add_edge(input_node, multiply_node)
        alternative.add_edge(sigmoid_chain[-1], multiply_node)
        h_swish.add_pattern_alternative(alternative)
    return h_swish
예제 #5
0
def test_join_pattern_with_special_input_node():
    """Join a pattern made only of an ANY-type node, then join ``third_pattern``.

    The expected result is ``first -> third``. The ANY-type ``second`` node is
    not part of the reference graph.
    """
    import copy

    # Work on a deep copy: join_patterns is called for its effect on
    # ``pattern``, and mutating the class-level fixture would leak into other tests.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second',
                            type=GraphPattern.ANY_PATTERN_NODE_TYPE)
    pattern.join_patterns(second_pattern)
    pattern.join_patterns(TestPattern.third_pattern)

    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    assert pattern == ref_pattern
예제 #6
0
def test_ops_combination_two_patterns():
    """Check ``+`` (edge from first to second) and ``|`` (no edge) on two patterns."""
    # `+`: every node of the first pattern gets an edge to the second node.
    chained = TestPattern.first_pattern + TestPattern.second_pattern
    expected_chain = GraphPattern()
    expected_chain.add_node(label='first', type=TestPattern.first_type)
    second_node = expected_chain.add_node(label='second', type=TestPattern.second_type)
    predecessors = [n for n in expected_chain.graph.nodes if n != second_node]
    for predecessor in predecessors:
        expected_chain.add_edge(predecessor, second_node)
    assert expected_chain == chained

    # `|`: both nodes are present, with no edge between them.
    alternatives = TestPattern.first_pattern | TestPattern.second_pattern
    expected_alternatives = GraphPattern()
    expected_alternatives.add_node(label='first', type=TestPattern.first_type)
    expected_alternatives.add_node(label='second', type=TestPattern.second_type)
    assert expected_alternatives == alternatives
예제 #7
0
def test_ops_combination_three_patterns():
    """Check ``+``, ``|`` and ``join_patterns`` combinations over three patterns."""
    # `+` binds tighter than `|`, so this is (first + second) | third.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    # The `|` operand is expected as a third node with no edges.
    _ = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    pattern = TestPattern.first_pattern | TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    _ = ref_pattern.add_node(label='first', type=TestPattern.first_type)
    _ = ref_pattern.add_node(label='second', type=TestPattern.second_type)
    _ = ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # Explicit join: every node of (first + second) gets an edge to third.
    pattern = (TestPattern.first_pattern + TestPattern.second_pattern)
    pattern_nodes = list(pattern.graph.nodes)
    third_nodes = list(TestPattern.third_pattern.graph.nodes)
    edges = list(itertools.product(pattern_nodes, third_nodes))
    pattern.join_patterns(TestPattern.third_pattern, edges)
    ref_pattern = GraphPattern()
    # Label fixed from 'second': this node carries first_type and mirrors first_pattern.
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    assert ref_pattern == pattern
예제 #8
0
def _get_tf_hw_fused_patterns() -> HWFusedPatterns:
    """Build the fused-operation pattern registry used for TensorFlow.

    Registers, all with ``match=True``: H_SIGMOID, H_SWISH, LINEAR
    (optionally followed by Identity), BN_ACT_OR_ACT_BN,
    'LINEAR + ANY_AG_BN_ACT_COMBO' and 'ELTWISE + ANY_AG_BN_ACT_COMBO'.
    """
    def single_op(**node_attrs) -> GraphPattern:
        # One-node pattern built from an operation-description mapping.
        op_pattern = GraphPattern()
        op_pattern.add_node(**node_attrs)
        return op_pattern

    fused = HWFusedPatterns()
    linear_ops = single_op(**LINEAR_OPERATIONS)
    eltwise_ops = single_op(**ELEMENTWISE_OPERATIONS)
    batch_norm = single_op(**BATCH_NORMALIZATION_OPERATIONS)

    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    fused.register(h_sigmoid, 'H_SIGMOID', match=True)
    fused.register(h_swish, 'H_SWISH', match=True)

    atomic_activations = single_op(**ATOMIC_ACTIVATIONS_OPERATIONS)
    activations = atomic_activations | h_swish | h_sigmoid
    # Either order of batch norm and activation.
    bn_act_either_order = (batch_norm + activations) | (activations + batch_norm)
    bn_or_act_or_both = batch_norm | activations | bn_act_either_order

    identity = single_op(type=['Identity'], label='IDENTITY')
    linear_then_optional_identity = linear_ops | (linear_ops + identity)

    agnostic_ops = single_op(**QUANTIZATION_AGNOSTIC_OPERATIONS)
    agnostic_act_or_bn_act = (agnostic_ops + activations) | bn_or_act_or_both

    fused.register(linear_then_optional_identity, name='LINEAR', match=True)
    fused.register(bn_act_either_order, name='BN_ACT_OR_ACT_BN', match=True)
    fused.register(linear_then_optional_identity + agnostic_act_or_bn_act,
                   'LINEAR + ANY_AG_BN_ACT_COMBO', match=True)
    fused.register(eltwise_ops + agnostic_act_or_bn_act,
                   'ELTWISE + ANY_AG_BN_ACT_COMBO', match=True)
    return fused
예제 #9
0
def _get_torch_hw_fused_patterns() -> HWFusedPatterns:
    """Build the fused-operation pattern registry used for PyTorch.

    Single-operation building blocks (linear, matmul, batch norm, activations,
    arithmetic) are registered with ``match=False``. Composite fusions
    (e.g. 'LINEAR + BN_ACT_PERM', 'MATMUL + ARITHMETIC', 'GROUP_NORM + RELU',
    'L2_NORM') are registered with ``match=True``.
    NOTE(review): the exact meaning of ``match`` is defined by HWFusedPatterns —
    confirm there.
    """
    retval = HWFusedPatterns()
    linear_ops = GraphPattern()
    linear_ops.add_node(**LINEAR_OPERATIONS)
    retval.register(linear_ops, LINEAR_OPERATIONS['label'], match=False)

    matmul_ops = GraphPattern()
    matmul_ops.add_node(**MATMUL_OPERATIONS)
    # Register the matmul pattern itself (previously linear_ops was registered
    # under the MATMUL label by mistake).
    retval.register(matmul_ops, MATMUL_OPERATIONS['label'], match=False)

    batch_norm = GraphPattern()
    batch_norm.add_node(**BATCH_NORMALIZATION_OPERATIONS)
    retval.register(batch_norm,
                    BATCH_NORMALIZATION_OPERATIONS['label'],
                    match=False)

    atomic_activations = GraphPattern()
    atomic_activations.add_node(**ATOMIC_ACTIVATIONS_OPERATIONS)
    swish = create_swish_act()
    h_sigmoid = create_h_sigmoid_act()
    h_swish = create_h_swish_act()
    activations = atomic_activations | swish | h_swish | h_sigmoid
    retval.register(activations, 'ACTIVATIONS', match=False)

    arithmetic_ops = GraphPattern()
    arithmetic_ops.add_node(**ARITHMETIC_OPERATIONS)
    retval.register(arithmetic_ops,
                    ARITHMETIC_OPERATIONS['label'],
                    match=False)

    # BN then activation, activation then BN, or either one on its own.
    batch_norm_activations_permutation = batch_norm + activations | activations + batch_norm | batch_norm | activations

    retval.register(linear_ops + batch_norm_activations_permutation,
                    'LINEAR + BN_ACT_PERM',
                    match=True)
    retval.register(matmul_ops + arithmetic_ops,
                    'MATMUL + ARITHMETIC',
                    match=True)
    retval.register(batch_norm + activations, 'BN + ACTIVATIONS', match=True)
    retval.register(activations + batch_norm, 'ACTIVATIONS + BN', match=True)
    retval.register(arithmetic_ops + batch_norm_activations_permutation,
                    'ARITHMETIC + BN_ACT_PERM',
                    match=True)

    group_norm = GraphPattern()
    group_norm.add_node(**GROUP_NORMALIZATION_OPERATIONS)
    relu = GraphPattern()
    relu.add_node(**RELU_OPERATIONS)
    retval.register(group_norm + relu, 'GROUP_NORM + RELU', match=True)

    l2_norm = create_l2_norm()
    retval.register(l2_norm, 'L2_NORM', match=True)
    return retval