def before_nchw():
    """Build the NCHW input graph: a 3x3 conv2d followed by a channel sum.

    Returns
    -------
    relay.Function
        Function over the free variables of the graph (``x`` and ``weight1``).
    """
    data = relay.var("x", shape=(1, 64, 56, 56))
    kernel = relay.var("weight1")
    conv = relay.nn.conv2d(data, kernel, channels=32, kernel_size=(3, 3), padding=(1, 1))
    # Axis 1 is the C axis in NCHW; keep the reduced dimension.
    reduced = relay.sum(conv, axis=1, keepdims=True)
    return relay.Function(analysis.free_vars(reduced), reduced)
Example #2
0
def verify_sum_grad(d_shape, axis=None, keepdims=False, exclude=False):
    """Gradient-check ``relay.sum`` on a float32 input of shape ``d_shape``.

    ``axis``, ``keepdims`` and ``exclude`` are forwarded unchanged to
    ``relay.sum``; the resulting single-argument function is passed to
    ``check_grad``.
    """
    tensor_type = relay.TensorType(d_shape, "float32")
    data = relay.var("data", tensor_type)
    reduction = relay.sum(data, axis=axis, keepdims=keepdims, exclude=exclude)
    check_grad(relay.Function([data], reduction))
 def before_nhwc():
     """Build the NHWC input graph: a 3x3 conv2d followed by a channel sum."""
     data = relay.var("x", shape=(1, 56, 56, 64))
     kernel = relay.var("weight1")
     conv = relay.nn.conv2d(
         data,
         kernel,
         channels=32,
         kernel_size=(3, 3),
         padding=(1, 1),
         data_layout="NHWC",
     )
     # Axis 3 is the C axis in NHWC; keep the reduced dimension.
     reduced = relay.sum(conv, axis=3, keepdims=True)
     return relay.Function(analysis.free_vars(reduced), reduced)
 def expected_nchw():
     """Build the expected NCHW graph.

     The input is converted NCHW -> NCHW16c for the conv2d, converted back to
     NCHW, and then summed over axis 1 with the reduced dimension kept.
     """
     data = relay.var("x", shape=(1, 64, 56, 56))
     kernel = relay.var("weight1")
     packed = relay.layout_transform(data, "NCHW", "NCHW16c")
     conv = relay.nn.conv2d(
         packed,
         kernel,
         channels=32,
         kernel_size=(3, 3),
         padding=(1, 1),
         data_layout="NCHW16c",
     )
     unpacked = relay.layout_transform(conv, "NCHW16c", "NCHW")
     reduced = relay.sum(unpacked, axis=[1], keepdims=True)
     return relay.Function(analysis.free_vars(reduced), reduced)
Example #5
0
def test_broadcast_to_const_shape_int64():
    """``broadcast_to`` with an int64 constant target shape, followed by a sum.

    A length-1 int64 vector is broadcast to shape (1, 5) and reduced over
    axis 0. The result is compared against NumPy on every enabled target,
    with both the graph and debug executors.
    """
    target_shape = relay.const(np.array([1, 5]), dtype="int64")
    x_var = relay.var("x", shape=(1,), dtype="int64")
    broadcast = relay.broadcast_to(x_var, shape=target_shape)
    func = relay.Function([x_var], relay.sum(broadcast, axis=0))

    x_np = np.random.randint(10, size=(1,), dtype="int64")
    # Summing a (1, 5) array over axis 0 leaves the values unchanged, so the
    # reference is simply the (5,) broadcast.
    expected = np.broadcast_to(x_np, (5,))
    for target, dev in tvm.testing.enabled_targets():
        for kind in ("graph", "debug"):
            executor = relay.create_executor(kind, device=dev, target=target)
            actual = executor.evaluate(func)(x_np)
            tvm.testing.assert_allclose(actual.numpy(), expected)
Example #6
0
def relay_take_grad_inp(c, _nb_indices, _indices, _values):
    """Relay lowering for ``take_grad_inp``.

    Builds an ``(nb_indices, n_cols)`` tensor. Row ``i`` is the sum of the
    entries of ``_values`` at positions where ``_indices == i``, with
    ``n_cols`` being the last dimension of ``_values``.

    ``_nb_indices`` must be a constant int; it gives the number of output rows.
    """
    assert _nb_indices.is_constant(int)
    values = c.ref(_values)
    # Trailing singleton axis so the equality mask broadcasts against values.
    index_col = relay.reshape(c.ref(_indices),
                              tuple(_indices.abstract.xshape()) + (1, ))
    n_rows = _nb_indices.value
    n_cols = _values.abstract.xshape()[-1]
    indices_dtype = type_to_np_dtype(_indices.abstract.element.xtype())
    out_dtype = type_to_np_dtype(_values.abstract.element.xtype())

    def _row(row_id):
        # Zero out every entry whose index differs from row_id, then sum.
        mask = relay.cast(
            relay.equal(index_col, relay.const(row_id, indices_dtype)),
            out_dtype)
        flat = relay.reshape(relay.multiply(mask, values), (-1, n_cols))
        return relay.reshape(relay.sum(flat, 0), (1, n_cols))

    return relay.concatenate([_row(row_id) for row_id in range(n_rows)], 0)
Example #7
0
def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
    """Legalizes s8, s8 -> s32 GEMM op for VNNI.

    The int8 left operand is shifted into uint8 by adding 128. The product is
    then corrected by subtracting ``128 * sum(y, axis=-1)``:

        x . y = (x + 128) . y - 128 * sum_k y[..., k]

    Parameters
    ----------
    inputs : pair of relay.Expr
        ``(x, y)`` operands of ``op``.
    arg_types : list
        Types of the operands; ``arg_types[0]`` must be int8 for the rewrite.
    op : callable
        Relay op builder invoked as ``op(x, y, **attrs)``.
    attrs : mapping
        Keyword attributes forwarded to ``op``.
    need_expand : bool
        If True, an axis is inserted at position 1 of the correction term
        before subtracting. Presumably this is needed when ``y`` carries a
        leading batch axis; confirm with callers.

    Returns
    -------
    relay.Expr or None
        The legalized expression, or None when the rewrite does not apply.
    """
    lhs_type, rhs_type = arg_types[0], arg_types[1]
    if not (check_vnni_applicable(lhs_type, rhs_type) and lhs_type.dtype == "int8"):
        return None

    x, y = inputs
    shifted_x = relay.cast(
        relay.add(relay.cast(x, "int32"), relay.const(128, "int32")), "uint8")

    correction = relay.const(128, "int32") * relay.sum(
        relay.cast(y, "int32"), axis=[-1])
    if need_expand:
        correction = relay.expand_dims(correction, axis=1)

    return relay.subtract(op(shifted_x, y, **attrs), correction)
Example #8
0
File: relay.py  Project: GonChen/myia
def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):
    """Relay lowering of the gradient of conv2d with respect to its weight.

    The gradient is computed as one grouped conv2d, with batch and channels
    folded into the group axis. It correlates every (batch, input channel)
    plane of ``data`` with the matching planes of ``dout``, and the result is
    then summed over the batch. Handles plain (groups == 1), depthwise and
    general grouped convolution.

    Parameters (myia graph nodes):
      data   -- conv2d input; its abstract shape is unpacked as
                (batch, in_channel, in_h, in_w)
      wsize  -- constant tuple unpacked as (out_channel, _, filter_h, filter_w)
      dout   -- gradient of the conv2d output; shape unpacked as
                (_, _, grad_h, grad_w)
      stride, pad, dil -- constant 2-tuples of the forward conv2d
      groups -- constant int, number of groups of the forward conv2d

    Returns a Relay expression of shape
    (out_channel, in_channel // groups, filter_h, filter_w).
    """
    assert wsize.is_constant(tuple)
    assert stride.is_constant(tuple)
    assert pad.is_constant(tuple)
    assert dil.is_constant(tuple)
    assert groups.is_constant(int)

    batch, in_channel, in_h, in_w = data.abstract.xshape()
    out_channel, _, filter_h, filter_w = wsize.value
    _, _, grad_h, grad_w = dout.abstract.xshape()
    pad_h, pad_w = pad.value
    n_groups = groups.value
    in_per_group = in_channel // n_groups
    out_per_group = out_channel // n_groups

    data = c.ref(data)
    dout = c.ref(dout)

    # Spatial size of the correlation before cropping to the filter size.
    # The forward conv pads pad_h / pad_w on each side (2 * pad in total).
    padded_weight_grad_h = ((in_h - (grad_h - 1) * stride.value[0] - 1 +
                             2 * pad_h) // dil.value[0] + 1)
    padded_weight_grad_w = ((in_w - (grad_w - 1) * stride.value[1] - 1 +
                             2 * pad_w) // dil.value[1] + 1)

    # Kernel (b, g, k, j) = dout[b, g * out_per_group + j], repeated for every
    # input channel k of group g. Flattened in this order, it lines up with
    # the conv2d below (groups = batch * in_channel), in which input channel
    # (b, g * in_per_group + k) produces out_per_group consecutive outputs.
    dout = relay.reshape(dout,
                         [batch, n_groups, 1, out_per_group, grad_h, grad_w])
    dout = relay.tile(dout, [1, 1, in_per_group, 1, 1, 1])
    dout = relay.reshape(dout, [-1, 1, grad_h, grad_w])
    data = relay.reshape(data, [1, -1, 0, 0])

    d = relay.nn.conv2d(data, dout, strides=dil.value, padding=pad.value,
                        dilation=stride.value, groups=batch * in_channel)
    d = relay.reshape(d, [batch, n_groups, in_per_group, out_per_group,
                          padded_weight_grad_h, padded_weight_grad_w])
    d = relay.sum(d, axis=0)
    # (groups, in_per_group, out_per_group, ...) ->
    # (groups, out_per_group, in_per_group, ...) -> (out_channel, in_per_group, ...)
    d = relay.transpose(d, [0, 2, 1, 3, 4])
    d = relay.reshape(d, [out_channel, in_per_group,
                          padded_weight_grad_h, padded_weight_grad_w])
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        d = relay.strided_slice(d, begin=[0, 0, 0, 0],
                                end=[None, None, filter_h, filter_w])
    return d
Example #9
0
def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):
    """Relay lowering of the gradient of conv2d with respect to its weight.

    The gradient is computed as one grouped conv2d, with batch and channels
    folded into the group axis, followed by a sum over the batch. Handles
    plain (groups == 1), depthwise and general grouped convolution.

    Parameters (myia graph nodes):
      data   -- conv2d input; shape unpacked as (batch, in_channel, in_h, in_w)
      wsize  -- constant tuple unpacked as (out_channel, _, filter_h, filter_w)
      dout   -- gradient of the conv2d output; shape unpacked as
                (grad_sh0, grad_sh1, grad_h, grad_w)
      stride, pad, dil -- constant 2-tuples of the forward conv2d
      groups -- constant int, number of groups of the forward conv2d

    Returns a Relay expression of shape
    (grad_sh1, in_channel // groups, filter_h, filter_w).
    """
    # This implementation should match the one in pytorch backend
    # (myia.compile.backends.pytorch_conv_grad.conv2d_weight)
    # NOTE(review): the channel arrangement below is valid for any groups
    # value; confirm the pytorch backend handles general groups the same way.

    assert wsize.is_constant(tuple)
    assert stride.is_constant(tuple)
    assert pad.is_constant(tuple)
    assert dil.is_constant(tuple)
    assert groups.is_constant(int)

    batch, in_channel, in_h, in_w = data.abstract.xshape()
    out_channel, _, filter_h, filter_w = wsize.value
    grad_sh0, grad_sh1, grad_h, grad_w = dout.abstract.xshape()
    pad_h, pad_w = pad.value
    n_groups = groups.value
    in_per_group = in_channel // n_groups
    out_per_group = grad_sh1 // n_groups

    data = c.ref(data)
    dout = c.ref(dout)

    # Spatial size of the correlation before cropping to the filter size.
    # The forward conv pads pad_h / pad_w on each side (2 * pad in total).
    padded_weight_grad_h = (in_h - (grad_h - 1) * stride.value[0] - 1 +
                            2 * pad_h) // dil.value[0] + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * stride.value[1] - 1 +
                            2 * pad_w) // dil.value[1] + 1

    # Kernel (b, g, k, j) = dout[b, g * out_per_group + j], repeated for every
    # input channel k of group g. Flattened in this order, it lines up with
    # the conv2d below (groups = batch * in_channel), in which input channel
    # (b, g * in_per_group + k) produces out_per_group consecutive outputs.
    dout = relay.reshape(
        dout, [grad_sh0, n_groups, 1, out_per_group, grad_h, grad_w])
    dout = relay.tile(dout, [1, 1, in_per_group, 1, 1, 1])
    dout = relay.reshape(dout, [-1, 1, grad_h, grad_w])
    data = relay.reshape(data, [1, -1, 0, 0])

    d = relay.nn.conv2d(
        data,
        dout,
        strides=dil.value,
        padding=pad.value,
        dilation=stride.value,
        groups=batch * in_channel,
    )

    # Split the per-batch blocks out of the channel axis and sum over batch.
    d = relay.reshape(
        d,
        [
            batch,
            n_groups,
            in_per_group,
            out_per_group,
            padded_weight_grad_h,
            padded_weight_grad_w,
        ],
    )
    d = relay.sum(d, axis=0)
    # (groups, in_per_group, out_per_group, ...) ->
    # (groups, out_per_group, in_per_group, ...) -> (grad_sh1, in_per_group, ...)
    d = relay.transpose(d, [0, 2, 1, 3, 4])
    d = relay.reshape(
        d,
        [grad_sh1, in_per_group, padded_weight_grad_h, padded_weight_grad_w],
    )

    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        d = relay.strided_slice(d,
                                begin=[0, 0, 0, 0],
                                end=[None, None, filter_h, filter_w])
    return d
Example #10
0
def _conv2d_legalize(attrs, inputs, arg_types):
    """Legalizes Conv2D op.

    Int8 x int8 convolutions are rewritten as uint8 x int8: the data is
    shifted by +128 and the result is corrected by ``128 * sum(kernel)``.
    Channels are then padded so that fast int8 instructions can be used.
    Input channels are padded to a multiple of 4 (data and kernel). Output
    channels are padded to a multiple of 16 (kernel only), and the extra
    output channels are sliced away afterwards.

    Only NHWC/HWIO and NCHW/OIHW layouts, dilation (1, 1) and groups == 1
    are handled.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types (data, kernel, output)

    Returns
    -------
    result : tvm.relay.Expr or None
        The legalized expr, or None when the op is left unchanged
    """
    # Dilation not supported yet. Return None if dilation is not (1, 1)
    dilation = attrs.get_int_tuple("dilation")
    if not (dilation[0] == 1 and dilation[1] == 1):
        return None

    # No legalization for grouped (including depthwise) convolutions yet.
    groups = attrs.get_int("groups")
    if groups != 1:
        return None

    # Input/output types and the input exprs.
    data_tensor, kernel_tensor = arg_types[0], arg_types[1]
    output_tensor = arg_types[2]
    data_dtype = data_tensor.dtype
    kernel_dtype = kernel_tensor.dtype
    data, kernel = inputs

    # Copy of the conv attrs; entries are overridden below as needed.
    new_attrs = {k: attrs[k] for k in attrs.keys()}

    is_nhwc = attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO"
    is_nchw = attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW"

    # If both the inputs are int8, we can add 128 to make the input dtype uint8,
    # and then adjust the output. This will help picking up Intel VNNI
    # instructions.
    #   C = A (conv) B                    A and B are int8
    #   C = (A + 128 - 128) (conv) B
    #   C = (A' conv B) - 128 (conv) B    where A' = A + 128
    # and 128 (conv) B is basically a reduce on CRS axis for weights.
    is_int8_inputs = data_dtype == "int8" and kernel_dtype == "int8"
    if is_int8_inputs:
        padding = attrs.get_int_tuple("padding")
        kh, kw = attrs.get_int_tuple("kernel_size")
        pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))

        if is_nhwc:
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"),
                                     axis=(0, 1, 2))
            pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0))
        elif is_nchw:
            pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr))
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"),
                                     axis=(1, 2, 3))
            # (O,) -> (O, 1, 1) so it broadcasts against the NCHW output.
            adjust_shift = relay.expand_dims(adjust_shift,
                                             axis=1,
                                             num_newaxis=2)
        else:
            return None

        data = relay.cast(data, "int32")
        data = relay.add(data, relay.const(128, "int32"))
        data = relay.cast(data, "uint8")

        # Pad externally: after the shift, the value representing 0 is 128.
        if any(padding):
            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
        new_attrs["padding"] = (0, 0)

        # The data type is now shifted to uint8.
        data_dtype = "uint8"

        # Multiply 128 to adjust shift.
        adjust_shift = relay.multiply(adjust_shift, relay.const(128, "int32"))

    # Legalize only if the datatypes are suitable for fast Int8 instructions.
    if not is_int8_hw_support(data_dtype, kernel_dtype):
        return None

    if is_nhwc:
        in_channel = data_tensor.shape[3].value
        out_channel = kernel_tensor.shape[3].value
    elif is_nchw:
        in_channel = data_tensor.shape[1].value
        out_channel = kernel_tensor.shape[0].value
    else:
        return None

    # Input channels: pad both data and kernel up to a multiple of 4.
    if in_channel % 4 != 0:
        new_in_channel = ((in_channel + 4) // 4) * 4
        diff = new_in_channel - in_channel
        if is_nhwc:
            data = relay.nn.pad(data,
                                pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, 0), (0, 0), (0, diff), (0, 0)))
        else:
            # NCHW data / OIHW kernel: the input-channel axis is 1 in both.
            channel_pad = ((0, 0), (0, diff), (0, 0), (0, 0))
            data = relay.nn.pad(data, pad_width=channel_pad)
            kernel = relay.nn.pad(kernel, pad_width=channel_pad)

    # Output channels: pad the kernel up to a multiple of 16.
    new_out_channel = out_channel
    if out_channel % 16 != 0:
        new_out_channel = ((out_channel + 16) // 16) * 16
        diff = new_out_channel - out_channel
        if is_nhwc:
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
        else:
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))

    if new_out_channel != out_channel:
        new_attrs["channels"] = new_out_channel
        out = relay.nn.conv2d(data, kernel, **new_attrs)
        # Slice the padded output channels away again.
        original_out_shape = [x.value for x in output_tensor.shape]
        out = relay.strided_slice(out,
                                  begin=[0, 0, 0, 0],
                                  end=original_out_shape)
    else:
        out = relay.nn.conv2d(data, kernel, **new_attrs)

    if is_int8_inputs:
        out = relay.subtract(out, adjust_shift)

    return out
Example #11
0
import sys

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Benchmark sum(W @ H) compiled by Relay for the LLVM CPU target.
# Usage: <script> W.npy H.npy   with W of shape (m, r) and H of shape (r, n).

if len(sys.argv) != 3:
    sys.exit("usage: %s W.npy H.npy" % sys.argv[0])

w_file = sys.argv[1]
h_file = sys.argv[2]

w_data = np.load(w_file).astype('float32')
h_data = np.load(h_file).astype('float32')

m = w_data.shape[0]
n = h_data.shape[1]
r = w_data.shape[1]
if h_data.shape[0] != r:
    sys.exit("shape mismatch: W is %s but H is %s"
             % (w_data.shape, h_data.shape))

w = relay.var('w', shape=(m, r), dtype='float32')
h = relay.var('h', shape=(r, n), dtype='float32')
# nn.dense(a, b) computes a @ b.T (b is laid out as (units, in_dim)), so H is
# transposed to obtain the (m, n) product W @ H. Passing h directly only
# type-checks when n == r, and then computes W @ H.T with an (m, r) result.
program = relay.nn.dense(w, relay.transpose(h))
program = relay.sum(program, axis=None)
program = relay.Function([w, h], program)
module = relay.Module.from_expr(program)

ctx = tvm.cpu(0)
graph, lib, params = relay.build(module, 'llvm')

# Time the whole graph (dense + transpose + sum) through the graph runtime.
# Timing lib.entry_name only runs a single compiled kernel, which is
# presumably why the old code needed an (m, r) output buffer.
runtime = graph_runtime.create(graph, lib, ctx)
runtime.set_input('w', tvm.nd.array(w_data, ctx))
runtime.set_input('h', tvm.nd.array(h_data, ctx))
runtime.set_input(**params)
timer = runtime.module.time_evaluator('run', ctx)
res = timer()
print(res)
Example #12
0
def conv2d_alter_int8_common(
    data,
    data_tensor,
    kernel,
    kernel_tensor,
    output_tensor,
    attrs,
    data_dtype: str,
    in_channel_vector_length: int,
    out_channel_vector_length: int,
):
    """
    Convert Relay inputs/outputs so that they are suitable for fast Int8 instructions.

    Int8 instructions require input channels and output channels to be a
    multiple of the vector length. For input channels, we pad both the inputs
    and weights channels. For output channels, we pad the weight and
    stride_slice the output.

    If the data dtype differs from ``data_dtype``, the data is shifted by 128
    into the requested dtype. The output is corrected by ``128 * sum(kernel)``
    over the C, R and S axes.

    Arguments
    ---------
    data: Expr
        Data Expr
    data_tensor: Tensor
        Data tensor
    kernel: Expr
        Kernel Expr
    kernel_tensor: Tensor
        Kernel tensor
    output_tensor: Tensor
        Output tensor
    attrs: Conv2dAttrs
        Attributes of the computation
    data_dtype: "int8" or "uint8"
        Desired dtype of data. Data will be converted to this dtype before the main computation.
    in_channel_vector_length: int
        Length of vector units on target hardware. Input channels are padded to this length.
    out_channel_vector_length: int
        Output size of vector instruction. Output channels are padded to this length.

    Returns
    -------
    out : Expr or None
        Conv2d computation with inputs in the correct order for tensorization,
        or None if dilation != (1, 1), groups != 1, or the layout is not
        NHWC/HWIO or NCHW/OIHW.
    """
    # Dilation not supported yet. Return None if dilation is not (1, 1)
    dilation = attrs.get_int_tuple("dilation")
    if not (dilation[0] == 1 and dilation[1] == 1):
        return None

    # No legalization for depthwise convolutions yet.
    groups = attrs.get_int("groups")
    if groups != 1:
        return None

    # Get the conv attrs
    new_attrs = {k: attrs[k] for k in attrs.keys()}

    padding = attrs.get_int_tuple("padding")
    kh, kw = attrs.get_int_tuple("kernel_size")
    pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))

    is_nhwc = attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO"
    is_nchw = attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW"

    shift_data = data_tensor.dtype != data_dtype
    if shift_data:
        # How to convert data to uint8 (A is int8)
        #   C = A (conv) B
        #   C = (A + 128 - 128) (conv) B
        #   C = (A' conv B) - 128 (conv) B     where A' = A + 128
        # and 128 (conv) B is basically a reduce on CRS axis for weights.
        #
        # How to convert data to int8 (A is uint8)
        #   C = (A - 128 + 128) (conv) B
        #   C = (A' conv B) + 128 (conv) B     where A' = A - 128
        #
        # Spatial padding must use the shifted image of 0: 128 in uint8,
        # -128 in int8.
        if data_dtype == "uint8":
            before_shift = relay.add
            after_shift = relay.subtract
            pad_value = 128
        else:
            before_shift = relay.subtract
            after_shift = relay.add
            pad_value = -128

        if is_nhwc:
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"),
                                     axis=(0, 1, 2))
            pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0))
        elif is_nchw:
            pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr))
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"),
                                     axis=(1, 2, 3))
            # (O,) -> (O, 1, 1) so it broadcasts against the NCHW output.
            adjust_shift = relay.expand_dims(adjust_shift,
                                             axis=1,
                                             num_newaxis=2)
        else:
            return None

        data = relay.cast(data, "int32")
        data = before_shift(data, relay.const(128, "int32"))
        data = relay.cast(data, data_dtype)

        # Do external padding, as the pad value must be the shifted zero.
        if any(padding):
            data = relay.nn.pad(data, pad_width=pad_width, pad_value=pad_value)
        new_attrs["padding"] = (0, 0)

        # Multiply 128 to adjust shift.
        adjust_shift = relay.multiply(adjust_shift, relay.const(128, "int32"))

    # Find the value of input and output channel.
    if is_nhwc:
        in_channel = data_tensor.shape[3].value
        out_channel = kernel_tensor.shape[3].value
    elif is_nchw:
        in_channel = data_tensor.shape[1].value
        out_channel = kernel_tensor.shape[0].value
    else:
        return None

    if in_channel % in_channel_vector_length != 0:
        new_in_channel = ((in_channel + in_channel_vector_length) //
                          in_channel_vector_length) * in_channel_vector_length
        diff = new_in_channel - in_channel
        if is_nhwc:
            data = relay.nn.pad(data,
                                pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, 0), (0, 0), (0, diff), (0, 0)))
        else:
            # NCHW data / OIHW kernel: the input-channel axis is 1 in both.
            channel_pad = ((0, 0), (0, diff), (0, 0), (0, 0))
            data = relay.nn.pad(data, pad_width=channel_pad)
            kernel = relay.nn.pad(kernel, pad_width=channel_pad)

    new_out_channel = out_channel
    if out_channel % out_channel_vector_length != 0:
        new_out_channel = (
            (out_channel + out_channel_vector_length) //
            out_channel_vector_length) * out_channel_vector_length
        diff = new_out_channel - out_channel
        if is_nhwc:
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
        else:
            kernel = relay.nn.pad(kernel,
                                  pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))

    if new_out_channel != out_channel:
        new_attrs["channels"] = new_out_channel
        out = relay.nn.conv2d(data, kernel, **new_attrs)
        # Slice the padded output channels away again.
        original_out_shape = [x.value for x in output_tensor.shape]
        out = relay.strided_slice(out,
                                  begin=[0, 0, 0, 0],
                                  end=original_out_shape)
    else:
        out = relay.nn.conv2d(data, kernel, **new_attrs)

    if shift_data:
        out = after_shift(out, adjust_shift)

    return out
Example #13
0
File: _nn.py  Project: chenghanpeng/tvm
def legalize_conv2d_backward_weight(attrs, inputs, types):
    """Legalize conv2d_backward_weight op.

    The weight gradient is expressed as a single grouped ``conv2d``. It
    correlates every (batch, input channel) plane of ``data`` with the output
    gradient planes of that channel's group, and the result is then summed
    over the batch. Valid for groups == 1, depthwise and general grouped
    convolution.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current op
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized: (grad, data)
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    grad, data = inputs
    data_shape = get_const_tuple(data.checked_type.shape)
    weight_shape = get_const_tuple(types[2].shape)
    _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
    batch, in_channel, in_h, in_w = data_shape
    _, _, filter_h, filter_w = weight_shape
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        get_const_tuple(attrs.padding), (filter_h, filter_w))
    stride_h, stride_w = get_const_tuple(attrs.strides)
    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
    groups = int(attrs.groups)
    in_per_group = in_channel // groups
    out_per_group = out_channel // groups

    # Kernel (b, g, k, j) = grad[b, g * out_per_group + j], repeated for every
    # input channel k of group g. Flattened in this order, it lines up with
    # the conv2d below (groups = batch * ic), in which input channel
    # (b, g * in_per_group + k) produces out_per_group consecutive outputs.
    grad = relay.reshape(grad,
                         [batch, groups, 1, out_per_group, grad_h, grad_w])
    grad = relay.tile(grad, [1, 1, in_per_group, 1, 1, 1])
    grad = relay.reshape(
        grad, [-1, 1, grad_h, grad_w])  # batch * ic * oc // groups, 1, oh, ow
    data = relay.reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw

    backward_weight = relay.nn.conv2d(
        data,
        grad,
        strides=attrs.dilation,
        padding=attrs.padding,
        dilation=attrs.strides,
        groups=in_channel * batch,
        out_dtype=attrs.out_dtype,
    )

    # infer shape of backward_weight
    padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top +
                            fpad_bottom) // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left +
                            fpad_right) // dilation_w + 1

    backward_weight = relay.reshape(
        backward_weight,
        [
            batch,
            groups,
            in_per_group,
            out_per_group,
            padded_weight_grad_h,
            padded_weight_grad_w,
        ],
    )
    backward_weight = relay.sum(backward_weight, axis=0)
    # (groups, in_per_group, out_per_group, ...) ->
    # (groups, out_per_group, in_per_group, ...) -> (oc, ic // groups, ...)
    backward_weight = relay.transpose(backward_weight, [0, 2, 1, 3, 4])
    backward_weight = relay.reshape(
        backward_weight,
        [out_channel, in_per_group, padded_weight_grad_h, padded_weight_grad_w],
    )

    assert padded_weight_grad_h >= filter_h
    assert padded_weight_grad_w >= filter_w

    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = relay.strided_slice(
            backward_weight,
            begin=[0, 0, 0, 0],
            end=[out_channel, in_per_group, filter_h, filter_w],
        )

    return backward_weight