def transform_pattern(pattern):
    """
    Replace a matched normalization pattern with one instance_norm or
    layer_norm op, then delete every op that belonged to the pattern.

    When ``pattern.requires_rank4_transpose`` is set, the norm input is first
    transposed with perm [0, 3, 1, 2] and the norm result is transposed back
    with perm [0, 2, 3, 1]. In that case the norm op gets a suffixed name and
    the final transpose inherits the original output name.

    :param pattern: A pattern object that contains all relevant information.
    """
    final_op = pattern.final_op
    out_name = final_op.outputs[0].name
    norm_axes = pattern.main_reduce.axes.val
    wrap_in_transpose = pattern.requires_rank4_transpose

    norm_input = pattern.main_reduce.x
    if wrap_in_transpose:
        norm_input = mb.transpose(
            x=norm_input,
            perm=[0, 3, 1, 2],
            name=out_name + "_transpose_nhwc_nchw",
            before_op=final_op,
        )

    if pattern.is_instancenorm:
        norm_name = (out_name + "_instancenorm") if wrap_in_transpose else out_name
        # instance_norm receives squeezed numpy values rather than the vars themselves.
        result = mb.instance_norm(
            x=norm_input,
            gamma=np.squeeze(pattern.gamma_var.val),
            beta=np.squeeze(pattern.beta_var.val),
            epsilon=pattern.epsilon_var,
            name=norm_name,
            before_op=final_op,
        )
    else:  # layer_norm
        norm_name = (out_name + "_layernorm") if wrap_in_transpose else out_name
        result = mb.layer_norm(
            x=norm_input,
            axes=norm_axes,
            gamma=pattern.gamma_var,
            beta=pattern.beta_var,
            epsilon=pattern.epsilon_var,
            name=norm_name,
            before_op=final_op,
        )

    if wrap_in_transpose:
        result = mb.transpose(
            x=result,
            perm=[0, 2, 3, 1],
            name=out_name + "_transpose_nchw_nhwc",
            before_op=final_op,
        )

    owning_block = final_op.enclosing_block
    owning_block.replace_uses_of_var_after_op(
        anchor_op=final_op,
        old_var=final_op.outputs[0],
        new_var=result,
    )
    # Drop every op of the matched pattern in a single call.
    pattern.block.remove_ops(pattern.op_list())
 def prog(x):
     # Build relu -> transpose -> reduce_mean -> log with fixed op names.
     activated = mb.relu(x=x, name="relu")
     permuted = mb.transpose(x=activated, perm=[0, 3, 1, 2], name="transpose")
     reduced = mb.reduce_mean(x=permuted, axes=[2, 3], keep_dims=False, name="reduce")
     logged = mb.log(x=reduced, name="log")
     # An add of two constants whose result is never used or returned.
     # NOTE(review): presumably present to exercise dead-op removal; confirm.
     mb.add(x=1, y=2)
     return logged
    def conv_bias_pattern(x):
        """Build conv (or conv_transpose) -> optional transpose -> add/sub with a scalar."""
        # Both convolution flavours are called with identical arguments.
        conv_builder = mb.conv_transpose if conv_transpose else mb.conv
        conv = conv_builder(x=x, weight=arbitrary_weight, pad_type="valid", name="conv")

        # Optionally insert a transpose between the conv and the bias op.
        bias_input = conv
        if transpose:
            bias_input = mb.transpose(x=conv, perm=arbitrary_perm, name="transpose")

        bias_builder = mb.sub if sub else mb.add
        return bias_builder(x=bias_input, y=arbitrary_scalar, name="add_or_sub")
Example #4
0
def transform_transpose_pattern(pattern):
    """
    Fold the constant of an add/sub that follows a conv -> transpose chain
    into a new conv's bias, recreate the transpose after the new conv, and
    delete the ops of the original pattern.

    :param pattern: matched pattern exposing ``conv``, ``transpose``,
        ``add_or_sub`` and ``block``.
    """
    conv_op = pattern.conv
    bias_op = pattern.add_or_sub
    is_deconv = conv_op.op_type == "conv_transpose"

    # The constant operand of the add/sub is the bias, on whichever side it sits.
    lhs_const = bias_op.x.val
    bias = lhs_const if lhs_const is not None else bias_op.y.val
    # True when the constant is on the right, i.e. the conv result is the left operand.
    conv_is_lhs = bias_op.y.val is not None
    is_sub = bias_op.op_type == "sub"

    out_channels = conv_op.outputs[0].shape[1]
    conv_weight = conv_op.weight.val
    if conv_op.bias is None:
        conv_bias = np.zeros(out_channels).astype(conv_weight.dtype)
    else:
        conv_bias = conv_op.bias.val

    bias = _bias_mod_and_validity(bias, out_channels, pattern)

    # conv - c  ->  conv(W,  b - c)
    # c - conv  ->  conv(-W, c - b)
    negate_conv = is_sub and not conv_is_lhs
    if is_sub and conv_is_lhs:
        bias = -bias
    if negate_conv:
        conv_bias = -conv_bias
    new_bias = conv_bias + bias
    new_weight = -conv_weight if negate_conv else conv_weight

    # The new conv reuses every original input except weight/bias.
    # Key order matches the original: weight, bias, before_op, then the rest.
    conv_kwargs = {"weight": new_weight, "bias": new_bias, "before_op": conv_op}
    conv_kwargs.update(
        (key, var) for key, var in conv_op.inputs.items() if key not in ("weight", "bias")
    )
    conv_builder = mb.conv_transpose if is_deconv else mb.conv
    new_conv_out = conv_builder(**conv_kwargs)

    # The recreated transpose takes over the add/sub output name.
    out_name = bias_op.outputs[0].name
    transpose_kwargs = {"x": new_conv_out, "name": out_name, "before_op": pattern.transpose}
    transpose_kwargs.update(
        (key, var) for key, var in pattern.transpose.inputs.items() if key != "x"
    )
    result = mb.transpose(**transpose_kwargs)

    bias_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=bias_op, old_var=bias_op.outputs[0], new_var=result
    )

    # Drop every op of the matched pattern in a single call.
    pattern.block.remove_ops(pattern.op_list())
Example #5
0
    def transform_pattern(pattern):
        """Replace the matched ops with a prelu op followed by a transpose using the pattern's perm."""
        perm = pattern.transpose.perm.val
        anchor = pattern.out_op
        replaced = anchor.outputs[0]

        # Whichever operand of the multiply is constant supplies alpha.
        mul_op = pattern.alpha_mul
        alpha = mul_op.x.val
        if alpha is None:
            alpha = mul_op.y.val

        # alpha is negated and flattened to 1-D before being handed to prelu.
        prelu_out = mb.prelu(x=pattern.root_var, alpha=-1 * alpha.flatten(), before_op=anchor)
        # The trailing transpose inherits the replaced output's name.
        result = mb.transpose(x=prelu_out, perm=perm, name=replaced.name, before_op=anchor)

        anchor.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=anchor, old_var=replaced, new_var=result
        )
        # Drop every op of the matched pattern in a single call.
        pattern.block.remove_ops(pattern.op_list())
Example #6
0
 def prelu_pattern(x):
     """Build a transpose op and feed it into the shared prelu sub-pattern."""
     # The perm written here is only a placeholder; per the original note, the
     # real value is checked in the "is_var_constraint_satisifed" method.
     transposed = mb.transpose(x=x, perm=[0, 1, 2, 3], name="transpose")
     return _prelu_pattern(transposed)