Пример #1
0
def create_h_sigmoid_act() -> GraphPattern:
    """Build the graph pattern describing the h-sigmoid activation.

    Two alternatives are registered on the returned pattern. Both are the
    chain ``*INPUT_NODE* -> AddV2 -> <activation> -> Mul`` and differ only in
    the activation node: type ``'ReLU'`` in the first alternative and type
    ``'Relu6'`` in the second.

    Returns:
        A ``GraphPattern`` holding both alternatives, ReLU variant first.
    """
    h_sigmoid = GraphPattern()

    # (label, type) of the activation node; the tuple order is the order in
    # which the alternatives are registered.
    activation_variants = (('RELU', 'ReLU'), ('RELU6', 'Relu6'))

    for act_label, act_type in activation_variants:
        alternative = GraphPattern()

        in_node = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)
        add = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        mul = alternative.add_node(label='TF_OP_MUL', type='Mul')

        # Linear chain: input -> add -> activation -> mul.
        for src, dst in ((in_node, add), (add, act), (act, mul)):
            alternative.add_edge(src, dst)

        h_sigmoid.add_pattern_alternative(alternative)

    return h_sigmoid
Пример #2
0
def test_join_patterns_func():
    """Join two patterns over every (first node, second node) pair.

    The reference is built by hand: a 'first' node, a 'second' node, and an
    edge from every other node to the 'second' node.
    """
    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    tail = expected.add_node(label='second', type=TestPattern.second_type)
    for node in expected.graph.nodes:
        if node == tail:
            continue
        expected.add_edge(node, tail)

    # Full cartesian product of node keys: connect everything to everything.
    all_pairs = list(itertools.product(list(TestPattern.first_pattern.graph.nodes),
                                       list(TestPattern.second_pattern.graph.nodes)))
    # join_patterns() works on the receiver, so operate on a copy of the fixture.
    joined = copy.copy(TestPattern.first_pattern)
    joined.join_patterns(TestPattern.second_pattern, all_pairs)
    assert expected == joined
Пример #3
0
def test_ops_combination_two_patterns():
    """Check the ``+`` and ``|`` operators on a pair of patterns.

    ``+`` is compared with a reference in which the 'first' node has an edge
    to the 'second' node; ``|`` is compared with a reference containing the
    same two nodes and no edges.
    """
    # first + second: two nodes joined by an edge.
    summed = TestPattern.first_pattern + TestPattern.second_pattern
    expected_sum = GraphPattern()
    expected_sum.add_node(label='first', type=TestPattern.first_type)
    tail = expected_sum.add_node(label='second', type=TestPattern.second_type)
    for node in expected_sum.graph.nodes:
        if node == tail:
            continue
        expected_sum.add_edge(node, tail)
    assert expected_sum == summed

    # first | second: the same two nodes, left unconnected.
    united = TestPattern.first_pattern | TestPattern.second_pattern
    expected_union = GraphPattern()
    expected_union.add_node(label='first', type=TestPattern.first_type)
    expected_union.add_node(label='second', type=TestPattern.second_type)
    assert expected_union == united
Пример #4
0
def test_join_pattern_with_special_input_node():
    """Join through a pattern whose only node has ``ANY_PATTERN_NODE_TYPE``.

    A copy of the first fixture pattern is joined with a single-node pattern
    of type ``GraphPattern.ANY_PATTERN_NODE_TYPE`` and then with the third
    fixture pattern. The expected result is a reference that contains only a
    'first' node and a 'third' node, connected by an edge.
    """
    # join_patterns() modifies the object it is called on. Work on a private
    # copy so the class-level fixture shared with the other tests stays intact
    # and the suite does not depend on test execution order.
    pattern = copy.deepcopy(TestPattern.first_pattern)
    second_pattern = GraphPattern()
    second_pattern.add_node(label='second',
                            type=GraphPattern.ANY_PATTERN_NODE_TYPE)
    pattern.join_patterns(second_pattern)
    pattern.join_patterns(TestPattern.third_pattern)

    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    # Connect every previously added node to the 'third' node.
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)

    assert pattern == ref_pattern
Пример #5
0
def test_join_patterns_func_three_patterns():
    """Join the third pattern onto ``first + second + third`` with all-pairs edges.

    The reference is built step by step: first -> second, then a 'third' node
    attached to the topologically last node, then another 'third' node that
    receives an edge from every other node.
    """
    joined = TestPattern.first_pattern + TestPattern.second_pattern + TestPattern.third_pattern
    # Every node already in the combined pattern gets an edge to every node of the third pattern.
    all_pairs = list(itertools.product(list(joined.graph.nodes),
                                       list(TestPattern.third_pattern.graph.nodes)))
    joined.join_patterns(TestPattern.third_pattern, all_pairs)

    expected = GraphPattern()
    expected.add_node(label='first', type=TestPattern.first_type)
    second = expected.add_node(label='second', type=TestPattern.second_type)
    for node in expected.graph.nodes:
        if node == second:
            continue
        expected.add_edge(node, second)

    # Attach the first 'third' node to whatever currently ends the topological order.
    topo_order = list(nx.topological_sort(expected.graph))
    chain_end = topo_order[-1]
    third = expected.add_node(label='third', type=TestPattern.third_type)
    expected.add_edge(chain_end, third)

    # The second 'third' node is fed by every node added so far.
    extra_third = expected.add_node(label='third', type=TestPattern.third_type)
    for node in expected.graph.nodes:
        if node == extra_third:
            continue
        expected.add_edge(node, extra_third)
    assert expected == joined
Пример #6
0
def test_ops_combination_three_patterns():
    """Check ``+``, ``|`` and ``join_patterns`` on three patterns.

    Three scenarios are compared with hand-built references:

    1. ``(first + second) | third`` - a first -> second chain plus an
       unconnected 'third' node.
    2. ``first | second | third`` - three unconnected nodes.
    3. ``first + second`` joined with the third pattern over every
       (existing node, third node) pair - 'third' receives an edge from both
       'first' and 'second'.
    """
    # Scenario 1. ``+`` binds tighter than ``|``; the parentheses only make that explicit.
    pattern = (TestPattern.first_pattern + TestPattern.second_pattern) | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # Scenario 2: pure alternatives, no edges at all.
    pattern = TestPattern.first_pattern | TestPattern.second_pattern | TestPattern.third_pattern
    ref_pattern = GraphPattern()
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    ref_pattern.add_node(label='second', type=TestPattern.second_type)
    ref_pattern.add_node(label='third', type=TestPattern.third_type)
    assert ref_pattern == pattern

    # Scenario 3: explicit all-pairs join of the third pattern.
    pattern = TestPattern.first_pattern + TestPattern.second_pattern
    pattern_nodes = list(pattern.graph.nodes)
    third_nodes = list(TestPattern.third_pattern.graph.nodes)
    edges = list(itertools.product(pattern_nodes, third_nodes))
    pattern.join_patterns(TestPattern.third_pattern, edges)
    ref_pattern = GraphPattern()
    # The node of first_type is labelled 'first', as in every other reference
    # graph in these tests (it was previously mislabelled 'second').
    ref_pattern.add_node(label='first', type=TestPattern.first_type)
    added_node = ref_pattern.add_node(label='second',
                                      type=TestPattern.second_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    added_node = ref_pattern.add_node(label='third',
                                      type=TestPattern.third_type)
    for node in ref_pattern.graph.nodes:
        if node != added_node:
            ref_pattern.add_edge(node, added_node)
    assert ref_pattern == pattern
Пример #7
0
def create_h_swish_act() -> GraphPattern:
    """Build the graph pattern describing the h-swish activation.

    Two alternatives are registered on the returned pattern. Each one is the
    chain ``*INPUT_NODE* -> AddV2 -> <activation> -> Mul`` followed by a
    ``Multiply`` node that takes edges from both the input node and the
    ``Mul`` node. The alternatives differ only in the activation node: type
    ``'ReLU'`` in the first and type ``'Relu6'`` in the second.

    Returns:
        A ``GraphPattern`` holding both alternatives, ReLU variant first.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    #  connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    #  in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    #  string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    #  after joining must return a list of output nodes so that these can be joined to later.
    #  Currently cannot specify h_swish in terms of h_sigmoid due to this.

    h_swish = GraphPattern()

    # (label, type) of the activation node; the tuple order is the order in
    # which the alternatives are registered.
    activation_variants = (('RELU', 'ReLU'), ('RELU6', 'Relu6'))

    for act_label, act_type in activation_variants:
        alternative = GraphPattern()
        in_node = alternative.add_node(label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)

        add = alternative.add_node(label='ADD', type='AddV2')
        act = alternative.add_node(label=act_label, type=act_type)
        mul = alternative.add_node(label='TF_OP_MUL', type='Mul')

        # Inner chain: input -> add -> activation -> mul.
        for src, dst in ((in_node, add), (add, act), (act, mul)):
            alternative.add_edge(src, dst)

        # Final node takes edges from the original input and from the inner chain's output.
        multiply = alternative.add_node(label='MULTIPLY', type='Multiply')
        alternative.add_edge(in_node, multiply)
        alternative.add_edge(mul, multiply)

        h_swish.add_pattern_alternative(alternative)

    return h_swish