def softmax_relu_pass():
    """Pattern pass: match Softmax(x) and rewrite it to ReLU(Sigmoid(x))."""
    inp = Any()
    pattern = Call(Prim(P.Softmax()), [inp])
    sigmoid_node = Call(Prim(P.Sigmoid()), [inp])
    target = Call(Prim(P.ReLU()), [sigmoid_node])
    return pattern, target
def softmax_neg_pass():
    """Pattern pass: match Softmax(x) or ReLU(x) and rewrite the match to Neg(match)."""
    inp = Any()
    softmax_branch = Call(Prim(P.Softmax()), [inp])
    relu_branch = Call(Prim(P.ReLU()), [inp])
    pattern = OneOf([softmax_branch, relu_branch])
    target = Call(Prim(P.Neg()), [pattern])
    return pattern, target
def softmax_relu_pass():
    """Pattern pass: match Softmax(x) or ReLU(x) and rewrite it to ReLU6(x)."""
    inp = Any()
    softmax_branch = Call(Prim(P.Softmax()), [inp])
    relu_branch = Call(Prim(P.ReLU()), [inp])
    pattern = OneOf([softmax_branch, relu_branch])
    target = Call(Prim(P.ReLU6()), [inp])
    return pattern, target
def single_bn_pass(): """ Sub a BN which does NOT take Conv as inputs to ReLU6. """ conv2d_prim = Prim("Conv2D") conv2d = Call(conv2d_prim) pattern_0 = NoneOf(conv2d) pattern = Call(P.BatchNorm(), [pattern_0]) target = Call(P.ReLU6(), [pattern_0]) return pattern, target
def single_bn_pass(): """ Sub a BN which does NOT take MatMul as inputs to ReLU6. """ matmul = Prim("MatMul") pattern_0 = NoneOf(matmul) softmax = P.Softmax() pattern = Call(softmax, [pattern_0]) relu6 = P.ReLU6() target = Call(relu6, [pattern_0]) return pattern, target
def softmax_relu_pass():
    """Pattern pass: match Sigmoid(x) or Softmax(x) and rewrite it to ReLU(x)."""
    inp = Any()
    # A Prim built from a list matches any one of the listed primitives.
    either_activation = Prim([P.Sigmoid(), P.Softmax()])
    pattern = Call(either_activation, [inp])
    target = Call(P.ReLU(), [inp])
    return pattern, target