Example #1
0
def create_h_sigmoid_act() -> GraphPattern:
    """Build the graph pattern for the h-sigmoid activation.

    The returned pattern holds two alternatives, registered in this order:
    one whose activation node has type 'ReLU' and one whose activation node
    has type 'Relu6'. In each alternative the nodes are created in the order
    input (non-pattern node), 'AddV2', activation, 'Mul', and an edge is
    added between every pair of consecutive nodes.

    :return: A `GraphPattern` containing both alternatives.
    """
    result = GraphPattern()

    # (label, type) of the activation node per alternative; the ReLU variant
    # is registered first, then the ReLU6 one.
    activation_variants = (
        ('RELU', 'ReLU'),
        ('RELU6', 'Relu6'),
    )

    for act_label, act_type in activation_variants:
        alternative = GraphPattern()

        # List elements are evaluated left to right, so the nodes are added
        # to the alternative in exactly the listed order.
        chain = [
            alternative.add_node(
                label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE),
            alternative.add_node(label='ADD', type='AddV2'),
            alternative.add_node(label=act_label, type=act_type),
            alternative.add_node(label='TF_OP_MUL', type='Mul'),
        ]

        # Link each node to its successor: input -> add -> activation -> mul.
        for src, dst in zip(chain, chain[1:]):
            alternative.add_edge(src, dst)

        result.add_pattern_alternative(alternative)

    return result
Example #2
0
def create_h_swish_act() -> GraphPattern:
    """Build the graph pattern for the h-swish activation.

    The returned pattern holds two alternatives, registered in this order:
    one built around a 'ReLU' activation node and one built around a 'Relu6'
    activation node. Each alternative contains the chain
    input (non-pattern node) - 'AddV2' - activation - 'Mul', plus a trailing
    'Multiply' node that is connected to both the input node and the 'Mul'
    node.

    :return: A `GraphPattern` containing both alternatives.
    """
    # TODO (vshampor): current approach with join_patterns is deficient since it does not allow to reliably
    #  connect nodes after a pattern has been joined. Along with the label and the type, the nodes created
    #  in the pattern must allow a "name" or "address" attribute, which must be a unique human readable
    #  string identifier of the node even if it has been joined multiple times, or perhaps each pattern
    #  after joining must return a list of output nodes so that these can be joined to later.
    #  Currently cannot specify h_swish in terms of h_sigmoid due to this.

    def _build_alternative(act_label: str, act_type: str) -> GraphPattern:
        # Assemble one h-swish alternative around the given activation node.
        alt = GraphPattern()
        source = alt.add_node(
            label='*INPUT_NODE*', type=GraphPattern.NON_PATTERN_NODE_TYPE)

        shifted = alt.add_node(label='ADD', type='AddV2')
        activated = alt.add_node(label=act_label, type=act_type)
        scaled = alt.add_node(label='TF_OP_MUL', type='Mul')

        alt.add_edge(source, shifted)
        alt.add_edge(shifted, activated)
        alt.add_edge(activated, scaled)

        # The final multiply takes both the original input and the result of
        # the add/activation/mul branch.
        product = alt.add_node(label='MULTIPLY', type='Multiply')
        alt.add_edge(source, product)
        alt.add_edge(scaled, product)
        return alt

    combined = GraphPattern()
    # ReLU version first, then the ReLU6 version.
    combined.add_pattern_alternative(_build_alternative('RELU', 'ReLU'))
    combined.add_pattern_alternative(_build_alternative('RELU6', 'Relu6'))
    return combined