def transform_pattern(pattern):
    """
    Replace the matched op sequence with a single instance_norm or
    layer_norm op (plus NHWC<->NCHW transposes when required), then
    delete all of the original ops.

    :param pattern: A pattern object that contains all relevant information.
    """
    out_name = pattern.final_op.outputs[0].name
    reduce_axes = pattern.main_reduce.axes.val
    needs_transpose = pattern.requires_rank4_transpose

    # Input fed to the normalization op: either the NHWC->NCHW transposed
    # tensor, or the original reduce input when no transpose is needed.
    norm_input = pattern.main_reduce.x
    if needs_transpose:
        norm_input = mb.transpose(
            x=pattern.main_reduce.x,
            perm=[0, 3, 1, 2],
            name=out_name + "_transpose_nhwc_nchw",
            before_op=pattern.final_op,
        )

    if pattern.is_instancenorm:
        # Only the final op in the replacement chain may take the original
        # output name; intermediate ops get a suffixed name.
        norm_name = out_name + "_instancenorm" if needs_transpose else out_name
        x = mb.instance_norm(
            x=norm_input,
            gamma=np.squeeze(pattern.gamma_var.val),
            beta=np.squeeze(pattern.beta_var.val),
            epsilon=pattern.epsilon_var,
            name=norm_name,
            before_op=pattern.final_op,
        )
    else:  # is_layernorm
        norm_name = out_name + "_layernorm" if needs_transpose else out_name
        x = mb.layer_norm(
            x=norm_input,
            axes=reduce_axes,
            gamma=pattern.gamma_var,
            beta=pattern.beta_var,
            epsilon=pattern.epsilon_var,
            name=norm_name,
            before_op=pattern.final_op,
        )

    if needs_transpose:
        # Restore the original NHWC layout.
        x = mb.transpose(
            x=x,
            perm=[0, 2, 3, 1],
            name=out_name + "_transpose_nchw_nhwc",
            before_op=pattern.final_op,
        )

    pattern.final_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.final_op,
        old_var=pattern.final_op.outputs[0],
        new_var=x,
    )
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
def prog(x):
    """Build a small program: relu -> transpose -> reduce_mean -> log.

    The extra add op is not consumed by the returned output — presumably a
    fixture op that the pattern matcher must leave untouched (TODO confirm
    against the test that uses this program).
    """
    relu_out = mb.relu(x=x, name="relu")
    nchw = mb.transpose(x=relu_out, perm=[0, 3, 1, 2], name="transpose")
    reduced = mb.reduce_mean(x=nchw, axes=[2, 3], keep_dims=False, name="reduce")
    logged = mb.log(x=reduced, name="log")
    _ = mb.add(x=1, y=2)
    return logged
def conv_bias_pattern(x):
    """Build a conv (or conv_transpose) -> [transpose] -> add/sub pattern.

    Which variant is built is controlled by the enclosing-scope flags
    ``conv_transpose``, ``transpose`` and ``sub``; the weight, perm and
    scalar values also come from the enclosing scope.
    """
    # Pick the convolution flavour, then build it.
    conv_builder = mb.conv_transpose if conv_transpose else mb.conv
    conv = conv_builder(x=x, weight=arbitrary_weight, pad_type="valid", name="conv")

    # Optionally interpose a transpose between the conv and the add/sub.
    bias_input = conv
    if transpose:
        bias_input = mb.transpose(x=conv, perm=arbitrary_perm, name="transpose")

    # Pick add or sub; both variants share the same op name.
    bias_builder = mb.sub if sub else mb.add
    return bias_builder(x=bias_input, y=arbitrary_scalar, name="add_or_sub")
def transform_transpose_pattern(pattern):
    """
    Fold the add/sub of a constant (reached through a transpose) into the
    preceding conv / conv_transpose op's weight and bias, then delete all
    of the matched ops.

    :param pattern: A pattern object that contains all relevant information.
    """
    is_deconv = pattern.conv.op_type == "conv_transpose"

    # The constant operand of the add/sub is the bias to fold in.
    if pattern.add_or_sub.x.val is not None:
        bias = pattern.add_or_sub.x.val
    else:
        bias = pattern.add_or_sub.y.val
    # The conv output is the first (x) input iff the constant sits on y.
    is_first_input = pattern.add_or_sub.y.val is not None
    is_sub = pattern.add_or_sub.op_type == "sub"

    # Existing conv weight and bias (a zero bias if the conv has none).
    conv_shape = pattern.conv.outputs[0].shape
    Cout = conv_shape[1]
    conv_weight = pattern.conv.weight.val
    conv_weight_type = conv_weight.dtype
    if pattern.conv.bias is None:
        conv_bias = np.zeros(Cout).astype(conv_weight_type)
    else:
        conv_bias = pattern.conv.bias.val

    bias = _bias_mod_and_validity(bias, Cout, pattern)

    # Fold the subtraction sign into the appropriate operand:
    #   conv_out - bias  -> negate the folded bias
    #   bias - conv_out  -> negate the conv's own bias (and weight, below)
    if is_sub:
        if is_first_input:
            bias = -bias
        else:
            conv_bias = -conv_bias
    new_bias = conv_bias + bias

    # "const - conv_out" negates the whole conv output, so the weight
    # flips sign as well.
    new_weight = -conv_weight if (is_sub and not is_first_input) else conv_weight

    # Rebuild the conv with the folded weight/bias, copying every other input.
    conv_kwargs = {"weight": new_weight, "bias": new_bias, "before_op": pattern.conv}
    conv_kwargs.update(
        (k, v) for k, v in pattern.conv.inputs.items() if k not in ("weight", "bias")
    )
    x = mb.conv_transpose(**conv_kwargs) if is_deconv else mb.conv(**conv_kwargs)

    # Rebuild the transpose on top of the new conv, taking over the
    # pattern's output name; copy every other transpose input.
    out_name = pattern.add_or_sub.outputs[0].name
    transpose_kwargs = {"x": x, "name": out_name, "before_op": pattern.transpose}
    transpose_kwargs.update(
        (k, v) for k, v in pattern.transpose.inputs.items() if k != "x"
    )
    x = mb.transpose(**transpose_kwargs)

    pattern.add_or_sub.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.add_or_sub,
        old_var=pattern.add_or_sub.outputs[0],
        new_var=x,
    )
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
def transform_pattern(pattern):
    """Replace the matched ops with a single prelu op followed by a transpose
    op, then delete all of the original ops."""
    perm = pattern.transpose.perm.val
    out_var = pattern.out_op.outputs[0]

    # The constant operand of the multiply holds the negated slope; flatten
    # and flip its sign to obtain the prelu alpha vector.
    mul_op = pattern.alpha_mul
    alpha = mul_op.x.val if mul_op.x.val is not None else mul_op.y.val
    alpha_vector = -1 * alpha.flatten()

    prelu_out = mb.prelu(
        x=pattern.root_var, alpha=alpha_vector, before_op=pattern.out_op
    )
    # The transpose takes over the original output name.
    new_out = mb.transpose(
        x=prelu_out, perm=perm, name=out_var.name, before_op=pattern.out_op
    )

    pattern.out_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.out_op, old_var=out_var, new_var=new_out
    )
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
def prelu_pattern(x):
    """Transpose followed by the shared prelu sub-pattern.

    The perm value here is a placeholder — any perm is accepted, since it
    is checked later by the "is_var_constraint_satisifed" method.
    """
    transposed = mb.transpose(x=x, perm=[0, 1, 2, 3], name="transpose")
    return _prelu_pattern(transposed)