コード例 #1
0
def conv_bwd(N, CI, HI, WI, CO, HO, WO, KSIZE, stride, padding, dtype):
    """Declare the backward computation of an NCHW 2-D convolution.

    Three placeholders are created -- ``data`` with shape (N, CI, HI, WI),
    ``weight`` with shape (CO, CI, KSIZE, KSIZE) and ``grad_output`` with
    shape (N, CO, HO, WO) -- and the gradients with respect to ``data`` and
    ``weight`` are expressed in terms of them.

    Parameters
    ----------
    N, CI, HI, WI
        Batch size, input channels, input height and input width.
    CO, HO, WO
        Output channels, output height and output width.
    KSIZE
        Side length of the square kernel.
    stride
        Stride, applied identically along both spatial axes.
    padding
        Padding amount. It is used directly in arithmetic (``2 * padding``),
        so it is presumably a single integer -- confirm against callers.
    dtype
        Data type used for the placeholders and the computed gradients.

    Returns
    -------
    list
        ``[data, weight, grad_output, grad_data, grad_weight]``.
    """
    stride_hw = (stride, stride)
    unit_dilation = (1, 1)

    # Given tensors.
    data = te.placeholder((N, CI, HI, WI), name="data", dtype=dtype)
    weight = te.placeholder((CO, CI, KSIZE, KSIZE), name="weight", dtype=dtype)
    grad_output = te.placeholder((N, CO, HO, WO),
                                 name="grad_output",
                                 dtype=dtype)

    # Gradient w.r.t. data: a transposed convolution of grad_output with
    # weight. The output padding is the gap between the requested input
    # extent and the extent computed from the output size alone.
    natural_h = (HO - 1) * stride_hw[0] - 2 * padding + KSIZE
    natural_w = (WO - 1) * stride_hw[1] - 2 * padding + KSIZE
    grad_data = topi.nn.conv2d_transpose_nchw(
        grad_output, weight, stride_hw, padding, dtype,
        (HI - natural_h, WI - natural_w))

    # Gradient w.r.t. weight: a grouped convolution in which the tiled and
    # reshaped grad_output acts as the kernel, with the forward stride
    # passed as the dilation and a unit stride.
    tiled_grad = topi.tile(grad_output, [1, CI, 1, 1])
    grouped_kernel = topi.reshape(tiled_grad, [N * CI * CO, 1, HO, WO])
    grouped_input = topi.reshape(data, [1, CI * N, HI, WI])
    grad_weight = topi.nn.group_conv2d_nchw(grouped_input,
                                            grouped_kernel,
                                            stride=unit_dilation,
                                            padding=padding,
                                            dilation=stride_hw,
                                            groups=CI * N,
                                            out_dtype=dtype)

    # Spatial extent of the grouped convolution result, needed to split its
    # channel axis back into (batch, in_channel, out_channel).
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (KSIZE, KSIZE))
    full_h = (HI - (HO - 1) * stride_hw[0] - 1 + pad_top +
              pad_bottom) // unit_dilation[0] + 1
    full_w = (WI - (WO - 1) * stride_hw[1] - 1 + pad_left +
              pad_right) // unit_dilation[1] + 1

    grad_weight = topi.reshape(grad_weight, [N, CI, CO, full_h, full_w])
    # Reduce over the batch axis, then swap the two channel axes so the
    # layout becomes (CO, CI, h, w).
    grad_weight = topi.sum(grad_weight, axis=0)
    grad_weight = topi.transpose(grad_weight, [1, 0, 2, 3])

    # Crop to the kernel size when the computed extent exceeds it.
    if full_h > KSIZE or full_w > KSIZE:
        grad_weight = topi.strided_slice(grad_weight,
                                         begin=[0, 0, 0, 0],
                                         end=[CO, CI, KSIZE, KSIZE])

    return [data, weight, grad_output, grad_data, grad_weight]
コード例 #2
0
def verify_transpose(in_shape, axes):
    """Compare ``topi.transpose`` with NumPy's transpose on each enabled target.

    For every ``(device, ctx)`` pair yielded by
    ``tvm.testing.enabled_targets()``, the transpose is scheduled with the
    injective schedule for that device, built, run on ``arange`` data of
    shape ``in_shape``, and its output is asserted close to
    ``ndarray.transpose(axes)``.

    Parameters
    ----------
    in_shape
        Shape of the input placeholder and of the generated test data.
    axes
        Axis permutation handed to both ``topi.transpose`` and NumPy.
    """
    src = te.placeholder(shape=in_shape, name="A")
    dst = topi.transpose(src, axes)

    for target, dev_ctx in tvm.testing.enabled_targets():
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            make_schedule = tvm.topi.testing.get_injective_schedule(target)
            sched = make_schedule(dst)
        kernel = tvm.build(sched, [src, dst], target, name="transpose")

        # Reference input and expected result computed on the host.
        flat = np.arange(np.prod(in_shape))
        host_in = flat.reshape(in_shape).astype(src.dtype)
        expected = host_in.transpose(axes)

        dev_in = tvm.nd.array(host_in, dev_ctx)
        dev_out = tvm.nd.empty(expected.shape, ctx=dev_ctx, dtype=dst.dtype)
        kernel(dev_in, dev_out)
        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)
コード例 #3
0
def verify_transpose(in_shape, axes):
    """Compare ``topi.transpose`` with NumPy's transpose on each backend.

    Iterates over the devices returned by ``get_all_backend()``. A device
    whose context reports ``exist`` as false is skipped with a message;
    otherwise the transpose is scheduled with the injective schedule for
    that device, built, run on ``arange`` data of shape ``in_shape``, and
    its output is asserted close to ``ndarray.transpose(axes)``.

    Parameters
    ----------
    in_shape
        Shape of the input placeholder and of the generated test data.
    axes
        Axis permutation handed to both ``topi.transpose`` and NumPy.
    """
    src = te.placeholder(shape=in_shape, name="A")
    dst = topi.transpose(src, axes)

    for target in get_all_backend():
        dev_ctx = tvm.context(target, 0)
        if not dev_ctx.exist:
            print("Skip because %s is not enabled" % target)
            continue

        print("Running on target: %s" % target)
        with tvm.target.create(target):
            make_schedule = tvm.topi.testing.get_injective_schedule(target)
            sched = make_schedule(dst)
        kernel = tvm.build(sched, [src, dst], target, name="transpose")

        # Reference input and expected result computed on the host.
        flat = np.arange(np.prod(in_shape))
        host_in = flat.reshape(in_shape).astype(src.dtype)
        expected = host_in.transpose(axes)

        dev_in = tvm.nd.array(host_in, dev_ctx)
        dev_out = tvm.nd.empty(expected.shape, ctx=dev_ctx, dtype=dst.dtype)
        kernel(dev_in, dev_out)
        tvm.testing.assert_allclose(dev_out.asnumpy(), expected)