            before_op=pattern.final_op,
        )

    pattern.final_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.final_op,
        old_var=pattern.final_op.outputs[0],
        new_var=x)
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
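
# The attribute names on `pattern` (here `final_op`) are not fixed by the
# infrastructure: each matched op is exposed under the `name=` it was given in
# the ops_arrangement program, alongside pattern.root_var, pattern.block, and
# pattern.op_list(). A minimal sketch of that naming contract (op names below
# are invented for illustration):
from coremltools.converters.mil import Builder as mb

@mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
def example_arrangement(x):
    # After a match, these ops become pattern.square and pattern.final_op,
    # and `x` is reachable as pattern.root_var.
    square = mb.mul(x=x, y=x, name="square")
    final_op = mb.add(x=square, y=1.0, name="final_op")
    return final_op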


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":
    register_generic_pass(
        ops_arrangement=instancenorm_or_layernorm,
        var_constraints=layernorm_1_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_layernorm_or_instancenorm",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=instancenorm_or_layernorm,
        var_constraints=instancenorm_1_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_layernorm_or_instancenorm",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=instancenorm_2,
        var_constraints=instancenorm_2_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_layernorm_or_instancenorm",
        namespace="common",
    )
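
# The multiple arrangements registered above match decomposed normalization
# subgraphs that are algebraically the same computation, e.g. the two common
# ways the variance can appear in a graph. A plain-numpy check of that
# equivalence (axis and shapes are illustrative, not taken from the pass
# source):
import numpy as np

x = np.random.rand(2, 3, 4)
gamma, beta, eps = np.random.rand(4), np.random.rand(4), 1e-5

mu = x.mean(axis=-1, keepdims=True)
# Variance written as E[(x - mu)^2] ...
var_a = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
# ... versus variance written as E[x^2] - mu^2: same value, different graph.
var_b = (x ** 2).mean(axis=-1, keepdims=True) - mu ** 2

out_a = gamma * (x - mu) / np.sqrt(var_a + eps) + beta
out_b = gamma * (x - mu) / np.sqrt(var_b + eps) + beta
assert np.allclose(out_a, out_b)
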
Example #2
    if is_deconv:
        x = mb.conv_transpose(**conv_kargs)
    else:
        x = mb.conv(**conv_kargs)

    pattern.scale.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.scale, old_var=pattern.scale.outputs[0], new_var=x)
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":
    register_generic_pass(
        ops_arrangement=conv_scale_mul,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_scale",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=conv_transpose_scale_mul,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_scale",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=conv_scale_div,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_scale",
        namespace="common",
    )
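
# Why the fusion is sound: convolution is linear, so scaling its output equals
# convolving with a pre-scaled kernel. A quick numpy sanity check of that
# identity (1-D with a scalar scale for brevity; not from the pass source):
import numpy as np

x = np.random.rand(16)   # input signal
w = np.random.rand(3)    # kernel
s = 2.5                  # the scale being fused away

lhs = np.convolve(x, w, mode="valid") * s        # mul after conv
rhs = np.convolve(x, w * s, mode="valid")        # conv with scaled kernel
assert np.allclose(lhs, rhs)
# For the div arrangements the kernel is divided by s instead.
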
Example #3

    if is_deconv:
        x = mb.conv_transpose(**conv_kargs)
    else:
        x = mb.conv(**conv_kargs)

    pattern.batchnorm.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.batchnorm,
        old_var=pattern.batchnorm.outputs[0],
        new_var=x)
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":
    register_generic_pass(
        ops_arrangement=conv_batchnorm,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_batchnorm",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=conv_transpose_batchorm,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_batchnorm",
        namespace="common",
    )
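
# Why the fold is exact: batchnorm applied to a conv output is an affine
# transform, which can be absorbed into the conv's own weight and bias. A
# numpy check of the folding algebra (1-D, single channel for brevity; not
# taken from the pass source):
import numpy as np

x, w, b = np.random.rand(16), np.random.rand(3), 0.3
gamma, beta, mean, var, eps = 1.7, 0.4, 0.2, 0.9, 1e-5

scale = gamma / np.sqrt(var + eps)

lhs = (np.convolve(x, w, mode="valid") + b - mean) * scale + beta
rhs = np.convolve(x, w * scale, mode="valid") + ((b - mean) * scale + beta)
assert np.allclose(lhs, rhs)
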
Example #4

    return passed


def transform_pattern(pattern):
    # remove all the ops, and replace with a gelu op
    out_name = pattern.mul_3.outputs[0].name
    x = mb.gelu(x=pattern.root_var,
                mode="TANH_APPROXIMATION",
                name=out_name,
                before_op=pattern.mul)

    pattern.mul_3.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.mul_3, old_var=pattern.mul_3.outputs[0], new_var=x)

    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
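
# Both gelu_to_detect_1 and gelu_to_detect_2 encode the standard tanh
# approximation of GELU, which is what makes collapsing the whole subgraph
# into a single gelu op with mode="TANH_APPROXIMATION" valid. The formula the
# matched ops compute, in plain numpy:
import numpy as np

def gelu_tanh(x):
    # GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
    return 0.5 * x * (1.0 + np.tanh(inner))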


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":
    register_generic_pass(ops_arrangement=gelu_to_detect_1,
                          var_constraints=var_constraints,
                          transform_pattern=transform_pattern,
                          pass_name="fuse_gelu_tanh_approximation",
                          namespace="common")

    register_generic_pass(ops_arrangement=gelu_to_detect_2,
                          var_constraints=var_constraints,
                          transform_pattern=transform_pattern,
                          pass_name="fuse_gelu_tanh_approximation",
                          namespace="common")
Example #5

def _get_bias_var(pattern):
    if pattern.add_or_sub.op_type == "sub":
        bias_var = pattern.add_or_sub.y
    else:
        bias_var = pattern.add_or_sub.x if pattern.add_or_sub.x.val is not None else pattern.add_or_sub.y

    return bias_var
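
# A hypothetical sketch (not part of the original pass) of how the returned
# var might be consumed: the constant payload lives in .val, and the sub case
# needs a sign flip, since conv(x) - b == conv(x) + (-b).
import numpy as np

def _get_bias_value(pattern):
    bias = _get_bias_var(pattern).val
    if pattern.add_or_sub.op_type == "sub":
        bias = np.negative(bias)
    return bias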


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":

    # conv -> add
    register_generic_pass(
        ops_arrangement=pattern_to_detect(False, False, False),
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_bias",
        namespace="common",
    )

    # conv -> sub
    register_generic_pass(
        ops_arrangement=pattern_to_detect(False, False, True),
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_bias",
        namespace="common",
    )

    # conv_transpose -> add
    register_generic_pass(
        # first flag assumed to select conv_transpose, by analogy with the
        # calls above
        ops_arrangement=pattern_to_detect(True, False, False),
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_conv_bias",
        namespace="common",
    )
Example #6
def test_generic_child_ordering():
    """
    Checks that the new generic pattern matching infrastructure works
    regardless of the ordering of an operation's children
    """

    @mb.program(input_specs=[mb.TensorSpec(shape=(3, 5, 6))])
    def prog(x):
        power = mb.pow(x=x, y=3, name="thepowerop")
        add_0 = mb.add(x=power, y=5, name="add_0")
        sub_0 = mb.sub(x=power, y=5, name="sub_0")
        mul_0 = mb.mul(x=power, y=5, name="mul_0")
        add_1 = mb.add(x=add_0, y=mul_0, name="add_1")
        add_2 = mb.add(x=sub_0, y=add_1, name="add_2")
        return add_2

    @mb.program(input_specs=[mb.TensorSpec(shape=(3, 5, 6))])
    def ops_arrangement(x):
        power = mb.pow(x=x, y=3, name="thepowerop")
        sub_0 = mb.sub(x=power, y=5, name="sub_0")
        add_0 = mb.add(x=power, y=5, name="add_0")
        mul_0 = mb.mul(x=power, y=5, name="mul_0")
        add_1 = mb.add(x=add_0, y=mul_0, name="add_1")
        add_2 = mb.add(x=sub_0, y=add_1, name="add_2")
        return add_2

    def var_constraints(pattern):
        constraints_passed = True
        constraints_passed &= _check_var_scalar_value(pattern.thepowerop.y, 3)
        constraints_passed &= _check_var_scalar_value(pattern.sub_0.y, 5)
        constraints_passed &= _check_var_scalar_value(pattern.add_0.x, 5) or _check_var_scalar_value(pattern.add_0.y, 5)
        constraints_passed &= _check_var_scalar_value(pattern.mul_0.x, 5) or _check_var_scalar_value(pattern.mul_0.y, 5)
        return constraints_passed

    def transform_pattern(pattern):
        out_name = "new operation"
        x = mb.gelu(x=pattern.root_var, mode="TANH_APPROXIMATION", name=out_name, before_op=pattern.thepowerop)

        pattern.add_2.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=pattern.add_2, old_var=pattern.add_2.outputs[0], new_var=x
        )

        pattern.block.remove_ops(pattern.op_list())

    register_generic_pass(ops_arrangement=ops_arrangement, var_constraints=var_constraints,
                          transform_pattern=transform_pattern, pass_name="test_generic_child_ordering",
                          namespace="common")

    prev_prog, prev_block, block = apply_pass_and_basic_check(
        prog, "common::test_generic_child_ordering"
    )
    assert get_op_types_in_program(prev_prog) == [
        "pow",
        "add",
        "sub",
        "mul",
        "add",
        "add",
    ]
    assert get_op_types_in_program(prog) == ["gelu"]
    assert_model_is_valid(
        prog,
        {"x": (3, 5, 6)},
        expected_output_shapes={block.outputs[0].name: (3, 5, 6)},
    )
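
# For reference, the helpers this test leans on; module paths follow the
# coremltools 5.x source tree and should be verified against your version:
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.experimental.passes.generic_pass_infrastructure import (
    register_generic_pass,
)
from coremltools.converters.mil.testing_utils import (
    apply_pass_and_basic_check,
    assert_model_is_valid,
    get_op_types_in_program,
)
# _check_var_scalar_value is a private coremltools helper whose module path
# varies by release, so it is omitted here.
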
Example #7

        k: v
        for k, v in pattern.linear.inputs.items()
        if k not in ["weight", "bias"]
    })

    x = mb.linear(**linear_kargs)

    pattern.add_or_sub.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.add_or_sub,
        old_var=pattern.add_or_sub.outputs[0],
        new_var=x)
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())


if os.getenv("ENABLE_EXPERIMENTAL_PASSES") == "1":
    register_generic_pass(
        ops_arrangement=pattern_add,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_linear_bias",
        namespace="common",
    )

    register_generic_pass(
        ops_arrangement=pattern_sub,
        var_constraints=var_constraints,
        transform_pattern=transform_pattern,
        pass_name="fuse_linear_bias",
        namespace="common",
    )
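
# As with the conv passes, the rewrite rests on a simple identity: an add (or
# sub) of a constant after a linear layer is the same layer with an adjusted
# bias. A numpy check (shapes illustrative):
import numpy as np

W, b0 = np.random.rand(4, 3), np.random.rand(4)
extra = np.random.rand(4)
x = np.random.rand(3)

lhs = (W @ x + b0) + extra       # linear followed by elementwise add
rhs = W @ x + (b0 + extra)       # single linear with folded bias
assert np.allclose(lhs, rhs)     # for sub, fold in -extra instead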