def create_h_sigmoid_act() -> GraphPattern:
    """Build the graph pattern describing the h_sigmoid activation.

    The returned pattern holds two alternatives, registered in this order:

    1. ``*INPUT_NODE*`` -> ``AddV2`` -> ``ReLU`` -> ``Mul``
    2. ``*INPUT_NODE*`` -> ``AddV2`` -> ``Relu6`` -> ``Mul``

    The input node is added with ``GraphPattern.NON_PATTERN_NODE_TYPE``.

    :return: A ``GraphPattern`` containing both alternatives.
    """
    alternatives = GraphPattern()

    # (label, type) of the activation node; this is the only thing that
    # differs between the two alternatives.
    # NOTE(review): 'ReLU' follows a different naming style than 'Relu6' --
    # presumably a Keras layer name vs. a raw TF op name; confirm.
    activation_variants = (
        ('RELU', 'ReLU'),
        ('RELU6', 'Relu6'),
    )

    for act_label, act_type in activation_variants:
        variant = GraphPattern()

        # Nodes are created in chain order, before any edge is added.
        chain = [
            variant.add_node(label='*INPUT_NODE*',
                             type=GraphPattern.NON_PATTERN_NODE_TYPE),
            variant.add_node(label='ADD', type='AddV2'),
            variant.add_node(label=act_label, type=act_type),
            variant.add_node(label='TF_OP_MUL', type='Mul'),
        ]

        # Connect each node to its successor: input -> add -> act -> mul.
        for producer, consumer in zip(chain, chain[1:]):
            variant.add_edge(producer, consumer)

        alternatives.add_pattern_alternative(variant)

    return alternatives
def create_h_swish_act() -> GraphPattern:
    """Build the graph pattern describing the h_swish activation.

    Each alternative is the chain ``*INPUT_NODE*`` -> ``AddV2`` ->
    activation -> ``Mul``, followed by a ``Multiply`` node that receives
    edges from both the input node and the ``Mul`` node. Two alternatives
    are registered, in this order: activation type ``ReLU``, then ``Relu6``.

    :return: A ``GraphPattern`` containing both alternatives.
    """
    # TODO (vshampor): the current approach with join_patterns is deficient
    # since it does not allow nodes to be reliably connected after a pattern
    # has been joined. Along with the label and the type, the nodes created
    # in the pattern must allow a "name" or "address" attribute, which must
    # be a unique human-readable string identifier of the node even if it
    # has been joined multiple times; or perhaps each pattern, after joining,
    # must return a list of output nodes so that these can be joined to later.
    # Currently h_swish cannot be specified in terms of h_sigmoid due to this.
    result = GraphPattern()

    # NOTE(review): 'ReLU' and 'Multiply' follow a different naming style
    # than 'Relu6' and 'Mul' -- presumably Keras layer names vs. raw TF op
    # names; confirm.
    for act_label, act_type in (('RELU', 'ReLU'), ('RELU6', 'Relu6')):
        candidate = GraphPattern()

        # The first part has the same nodes and edges as the h_sigmoid chain.
        source = candidate.add_node(label='*INPUT_NODE*',
                                    type=GraphPattern.NON_PATTERN_NODE_TYPE)
        shift = candidate.add_node(label='ADD', type='AddV2')
        activation = candidate.add_node(label=act_label, type=act_type)
        scale = candidate.add_node(label='TF_OP_MUL', type='Mul')
        candidate.add_edge(source, shift)
        candidate.add_edge(shift, activation)
        candidate.add_edge(activation, scale)

        # The final node takes both the pattern input and the Mul output.
        product = candidate.add_node(label='MULTIPLY', type='Multiply')
        for operand in (source, scale):
            candidate.add_edge(operand, product)

        result.add_pattern_alternative(candidate)

    return result