import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi
from tvm.topi.nn.utils import get_pad_tuple  # tvm.topi.nn.util in older releases


def conv_bwd(N, CI, HI, WI, CO, HO, WO, KSIZE, stride, padding, dtype):
    strides = (stride, stride)
    shape_data = (N, CI, HI, WI)
    shape_weight = (CO, CI, KSIZE, KSIZE)
    shape_grad_output = (N, CO, HO, WO)

    # Given tensors: forward inputs and the gradient w.r.t. the output.
    data = te.placeholder(shape_data, name="data", dtype=dtype)
    weight = te.placeholder(shape_weight, name="weight", dtype=dtype)
    grad_output = te.placeholder(shape_grad_output, name="grad_output", dtype=dtype)

    # grad_data: transposed convolution of grad_output with the weight.
    # output_padding recovers the exact input size when the forward stride
    # maps several input sizes to the same output size.
    out_h = (HO - 1) * strides[0] - 2 * padding + KSIZE
    out_w = (WO - 1) * strides[1] - 2 * padding + KSIZE
    output_padding = (HI - out_h, WI - out_w)
    grad_data = topi.nn.conv2d_transpose_nchw(
        grad_output, weight, strides, padding, dtype, output_padding)

    # grad_weight: a grouped convolution of data with grad_output, with the
    # roles of stride and dilation swapped relative to the forward pass.
    dilation_h, dilation_w = (1, 1)
    batch, in_channel, in_h, in_w = shape_data
    out_channel, _, filter_h, filter_w = shape_weight

    grad_output_tmp = topi.tile(grad_output, [1, in_channel, 1, 1])
    grad_output_tmp = topi.reshape(
        grad_output_tmp, [batch * in_channel * out_channel, 1, HO, WO])
    data_tmp = topi.reshape(data, [1, in_channel * batch, HI, WI])
    grad_weight = topi.nn.group_conv2d_nchw(
        data_tmp,
        grad_output_tmp,
        stride=(dilation_h, dilation_w),
        padding=padding,
        dilation=strides,
        groups=in_channel * batch,
        out_dtype=dtype,
    )

    # Infer the (possibly padded) shape of grad_weight.
    _, _, grad_h, grad_w = shape_grad_output
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        padding, (filter_h, filter_w))
    padded_weight_grad_h = (in_h - (grad_h - 1) * strides[0] - 1 +
                            fpad_top + fpad_bottom) // dilation_h + 1
    padded_weight_grad_w = (in_w - (grad_w - 1) * strides[1] - 1 +
                            fpad_left + fpad_right) // dilation_w + 1

    grad_weight = topi.reshape(grad_weight, [
        batch, in_channel, out_channel, padded_weight_grad_h,
        padded_weight_grad_w
    ])
    grad_weight = topi.sum(grad_weight, axis=0)
    grad_weight = topi.transpose(grad_weight, [1, 0, 2, 3])

    # Slice off any extra rows/columns introduced by padding.
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        grad_weight = topi.strided_slice(
            grad_weight,
            begin=[0, 0, 0, 0],
            end=[out_channel, in_channel, filter_h, filter_w])

    return [data, weight, grad_output, grad_data, grad_weight]
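
# A minimal usage sketch for conv_bwd, assuming an LLVM target. The shapes,
# the helper name _demo_conv_bwd, and the naive te.create_schedule are
# illustrative only -- a real pipeline would apply a tuned schedule.
def _demo_conv_bwd():
    N, CI, HI, WI = 1, 3, 8, 8
    CO, KSIZE, stride, padding = 4, 3, 1, 1
    HO = (HI + 2 * padding - KSIZE) // stride + 1
    WO = (WI + 2 * padding - KSIZE) // stride + 1
    tensors = conv_bwd(N, CI, HI, WI, CO, HO, WO, KSIZE, stride, padding, "float32")
    data, weight, grad_output, grad_data, grad_weight = tensors

    # Build both gradient computations into a single function.
    s = te.create_schedule([grad_data.op, grad_weight.op])
    func = tvm.build(s, tensors, "llvm")

    ctx = tvm.cpu(0)
    ins = [
        tvm.nd.array(
            np.random.uniform(size=[int(d) for d in t.shape]).astype("float32"), ctx)
        for t in (data, weight, grad_output)
    ]
    outs = [
        tvm.nd.empty([int(d) for d in t.shape], ctx=ctx, dtype="float32")
        for t in (grad_data, grad_weight)
    ]
    func(*ins, *outs)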
def verify_transpose(in_shape, axes):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)

    # Build and run the transpose on one target, then check the result
    # against the NumPy reference.
    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="transpose")
        data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
        out_npy = data_npy.transpose(axes)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)
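
# Illustrative invocations, mirroring the small shapes TVM's own topi tests
# use; axes=None exercises the default full-axis-reversal path.
def test_transpose():
    verify_transpose((3, 10, 2), (1, 0, 2))
    verify_transpose((3, 10, 5), (2, 0, 1))
    verify_transpose((3, 10), None)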