def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)):
    """Compute conv2d_transpose through the cuDNN dgrad (backward-data) kernel.

    Thin wrapper that forwards its arguments to ``cudnn.conv_backward_data``
    with a fixed dilation of (1, 1), convolution mode 1, tensor format 0 and
    a single group.

    Parameters
    ----------
    x, w :
        Input and weight, passed through unchanged as the first two arguments.
    stride, padding :
        Forwarded to the cuDNN call (note the call takes padding before stride).
    out_dtype :
        Forwarded as the convolution dtype argument.
    output_padding :
        Forwarded as the ``output_padding`` keyword; defaults to (0, 0).

    Returns
    -------
    Whatever ``cudnn.conv_backward_data`` returns.
    """
    dilation = (1, 1)
    # NOTE(review): presumably cuDNN's cross-correlation mode -- confirm
    # against the tvm.contrib.cudnn documentation.
    conv_mode = 1
    # Tensor format 0 is the NCHW format.
    tensor_format = 0
    return cudnn.conv_backward_data(
        x,
        w,
        padding,
        stride,
        dilation,
        conv_mode,
        tensor_format,
        out_dtype,
        groups=1,
        output_padding=output_padding,
    )
def conv2d_transpose_cudnn(
    x, w, stride, padding, out_dtype, output_padding=(0, 0), layout="NCHW"
):
    """Compute conv2d_transpose through the cuDNN dgrad (backward-data) kernel.

    Thin wrapper that forwards its arguments to ``cudnn.conv_backward_data``
    with a fixed dilation of (1, 1), convolution mode 1 and a single group.

    Parameters
    ----------
    x, w :
        Input and weight, passed through unchanged as the first two arguments.
    stride, padding :
        Forwarded to the cuDNN call (note the call takes padding before stride).
    out_dtype :
        Forwarded as the convolution dtype argument.
    output_padding :
        Forwarded as the ``output_padding`` keyword; defaults to (0, 0).
    layout :
        ``"NCHW"`` selects tensor format 0; every other value selects tensor
        format 1 (presumably NHWC -- confirm against tvm.contrib.cudnn).

    Returns
    -------
    Whatever ``cudnn.conv_backward_data`` returns.
    """
    if layout == "NCHW":
        tensor_format = 0
    else:
        # No validation here: anything that is not "NCHW" falls through to 1.
        tensor_format = 1

    dilation = (1, 1)
    # NOTE(review): presumably cuDNN's cross-correlation mode -- confirm
    # against the tvm.contrib.cudnn documentation.
    conv_mode = 1
    return cudnn.conv_backward_data(
        x,
        w,
        padding,
        stride,
        dilation,
        conv_mode,
        tensor_format,
        out_dtype,
        groups=1,
        output_padding=output_padding,
    )
def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-5):
    """Check ``cudnn.conv_backward_data`` against the TOPI conv2d_transpose reference.

    Builds random ``dy`` and ``w`` tensors for a fixed problem size (batch 3,
    4 input channels, 16 output channels, 3x3 filter, padding 1, stride 1,
    32x32 spatial size), computes the expected ``dx`` with the NumPy reference
    implementation, runs the cuDNN kernel on CUDA device 0 and compares.

    Parameters
    ----------
    data_dtype : str
        dtype of the random inputs and of the TE placeholders. For
        ``"float16"`` the reference is evaluated in float32 and cast back.
    conv_dtype : str
        Forwarded to ``cudnn.conv_backward_data`` as ``conv_dtype``.
    tensor_format : int
        0 builds NCHW-shaped tensors; any other value builds NHWC-shaped ones.
        The value is also forwarded to the cuDNN call.
    tol : float
        Used as both ``atol`` and ``rtol`` in the final comparison.

    Raises
    ------
    AssertionError
        If the kernel output differs from the reference beyond ``tol``.
    """
    batch = 3
    in_channel = 4
    out_channel = 16
    filter_h, filter_w = 3, 3
    pad_h, pad_w = 1, 1
    stride_h, stride_w = 1, 1
    height, width = 32, 32

    if tensor_format == 0:
        xshape = [batch, in_channel, height, width]
        wshape = [out_channel, in_channel, filter_h, filter_w]
        # Copy before editing: a plain assignment would alias xshape and
        # overwrite its channel entry as well.
        oshape = list(xshape)
        oshape[1] = out_channel
        ref_func = tvm.topi.testing.conv2d_transpose_nchw_python
    else:
        xshape = [batch, height, width, in_channel]
        wshape = [out_channel, filter_h, filter_w, in_channel]
        oshape = list(xshape)
        oshape[3] = out_channel

        def ref_func(dy_np, w_np, strides, padding, out_pad):
            # Move the leading out_channel axis of the weight to the end;
            # the result is handed to the reference under the "HWOI" label.
            return tvm.topi.testing.conv2d_transpose_nhwc_python(
                dy_np,
                np.transpose(w_np, [1, 2, 3, 0]),
                "HWOI",
                strides,
                padding,
                out_pad,
            )

    dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
    w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)

    if data_dtype == "float16":
        # Evaluate the reference in float32, then cast back to float16.
        dx_np = ref_func(
            dy_np.astype("float32"),
            w_np.astype("float32"),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (0, 0),
        )
        dx_np = dx_np.astype("float16")
    else:
        dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))

    dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
    w = te.placeholder(wshape, name="dw", dtype=data_dtype)
    dx = cudnn.conv_backward_data(
        dy,
        w,
        [pad_h, pad_w],
        [stride_h, stride_w],
        [1, 1],
        conv_mode=1,
        tensor_format=tensor_format,
        conv_dtype=conv_dtype,
        groups=1,
    )

    s = te.create_schedule(dx.op)
    dev = tvm.cuda(0)
    f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", name="conv2d_backward_data")

    dy_nd = tvm.nd.array(dy_np, dev)
    w_nd = tvm.nd.array(w_np, dev)
    # Start the output from zeros (same shape and dtype as the reference)
    # instead of the expected values, so that a kernel which writes nothing
    # cannot pass the comparison below.
    # NOTE(review): relies on the kernel overwriting the whole output buffer
    # rather than accumulating into it -- confirm against the cuDNN runtime call.
    dx_nd = tvm.nd.array(np.zeros_like(dx_np), dev)

    f(dy_nd, w_nd, dx_nd)
    tvm.testing.assert_allclose(dx_nd.numpy(), dx_np, atol=tol, rtol=tol)