def test_2d_conv():
    """Print a tiny convolution computed two ways for side-by-side inspection.

    A 2x2 filter is applied to a 3x3 single-channel input, once through
    ``jit_conv_4d`` and once as ``img2col`` followed by a matrix product.
    Inputs, weights and both outputs are printed; nothing is asserted or
    returned.
    """
    # Single source of truth for the problem geometry: every shape below is
    # derived from these names so the two code paths cannot drift apart.
    batch_size = 1
    stride = 1
    padding = 0
    fh = 2
    fw = 2
    input_channel = 1
    output_channel = 1
    iw = 3
    ih = 3
    (output_height, output_width) = calculate_output_size(
        ih, iw, fh, fw, padding, stride)

    # NOTE(review): 0.1 is presumably the learning rate and the Initialize
    # arguments a folder/name/create-new triple -- confirm against
    # ConvWeightsBias. Whatever Initialize produces, W and B are overwritten
    # with fixed values right after so the printed numbers are reproducible.
    wb = ConvWeightsBias(output_channel, input_channel, fh, fw,
                         InitialMethod.MSRA, OptimizerName.SGD, 0.1)
    wb.Initialize("test", "test", True)
    wb.W = np.array([3, 2, 1, 0]).reshape(output_channel, input_channel, fh, fw)
    wb.B = np.array([0])

    x = np.arange(batch_size * input_channel * ih * iw).reshape(
        batch_size, input_channel, ih, iw)

    # Path 1: the dedicated 4-D convolution routine.
    output1 = jit_conv_4d(x, wb.W, wb.B, output_height, output_width, stride)
    print("input=\n", x)
    print("weights=\n", wb.W)
    print("output=\n", output1)

    # Path 2: unfold the input with img2col, flatten each filter into one
    # column, and multiply.
    col = img2col(x, fh, fw, stride, padding)
    w = wb.W.reshape(output_channel, -1).T
    output2 = np.dot(col, w)
    print("input=\n", col)
    print("weights=\n", w)
    print("output2=\n", output2)
def understand_4d_im2col():
    """Print each step of a multi-channel, multi-batch im2col convolution.

    Builds a batch of 2 inputs with 3 channels of 3x3 values and 2 filters of
    size 2x2, fills both with consecutive integers, and prints the input, its
    ``img2col`` form, the weights, the flattened weights, the matrix product,
    and the product reshaped and transposed. Nothing is returned.
    """
    batch_size = 2
    stride = 1
    padding = 0
    fh = 2
    fw = 2
    input_channel = 3
    output_channel = 2
    iw = 3
    ih = 3
    (output_height, output_width) = calculate_output_size(
        ih, iw, fh, fw, padding, stride)

    # NOTE(review): 0.1 is presumably the learning rate -- confirm against
    # ConvWeightsBias. W and B are overwritten below with deterministic
    # values so the printed matrices are easy to follow by eye.
    wb = ConvWeightsBias(output_channel, input_channel, fh, fw,
                         InitialMethod.MSRA, OptimizerName.SGD, 0.1)
    wb.Initialize("test", "test", True)
    wb.W = np.arange(output_channel * input_channel * fh * fw).reshape(
        output_channel, input_channel, fh, fw)
    wb.B = np.array([0])

    x = np.arange(input_channel * iw * ih * batch_size).reshape(
        batch_size, input_channel, ih, iw)

    # Use the same named geometry that calculate_output_size received, so
    # changing fh/fw/stride/padding above keeps every step consistent.
    col = img2col(x, fh, fw, stride, padding)
    # One column per filter: shape (input_channel * fh * fw, output_channel).
    w = wb.W.reshape(output_channel, -1).T
    output = np.dot(col, w)
    print("x=\n", x)
    print("col_x=\n", col)
    print("weights=\n", wb.W)
    print("col_w=\n", w)
    print("output=\n", output)

    # presumably rows of `output` are ordered (batch, out_y, out_x) --
    # verify against img2col. The last axis holds one value per filter.
    out2 = output.reshape(batch_size, output_height, output_width, -1)
    print("out2=\n", out2)
    # Move the per-filter axis from last position to position 1.
    out3 = np.transpose(out2, axes=(0, 3, 1, 2))
    print("conv result=\n", out3)
def understand_4d_col2img_complex():
    """Print each step of an im2col/col2img backward pass on small data.

    Uses a batch of 2 inputs (3 channels, 3x3) and 2 filters of size 2x2, all
    filled with consecutive integers. Prints the unfolded input and weights,
    then an integer-filled incoming delta, the bias and weight gradients
    (each divided by the batch size), the column-form delta, and the result
    of ``col2img``. Nothing is returned.
    """
    batch_size, input_channel, output_channel = 2, 3, 2
    ih, iw = 3, 3
    fh, fw = 2, 2
    stride, padding = 1, 0
    output_height, output_width = calculate_output_size(
        ih, iw, fh, fw, padding, stride)

    # NOTE(review): 0.1 is presumably the learning rate -- confirm against
    # ConvWeightsBias. W and B are replaced with fixed values right after
    # Initialize so the printed numbers are reproducible.
    wb = ConvWeightsBias(output_channel, input_channel, fh, fw,
                         InitialMethod.MSRA, OptimizerName.SGD, 0.1)
    wb.Initialize("test", "test", True)
    weight_shape = (output_channel, input_channel, fh, fw)
    wb.W = np.arange(output_channel * input_channel * fh * fw).reshape(weight_shape)
    wb.B = np.array([0])

    input_shape = (batch_size, input_channel, ih, iw)
    x = np.arange(input_channel * iw * ih * batch_size).reshape(input_shape)
    print("x=\n", x)

    x_col = img2col(x, fh, fw, stride, padding)
    print("col_x=\n", x_col)
    print("w=\n", wb.W)

    # One column per filter: shape (input_channel * fh * fw, output_channel).
    w_col = wb.W.reshape(output_channel, -1).T
    print("col_w=\n", w_col)

    # ---- backward ----
    delta_shape = (batch_size, output_channel, output_height, output_width)
    delta_in = np.arange(
        batch_size * output_channel * output_height * output_width
    ).reshape(delta_shape)
    print("delta_in=\n", delta_in)

    # Move the channel axis last, then flatten to one row per
    # (batch, out_y, out_x) position and one column per output channel.
    delta_flat = delta_in.transpose(0, 2, 3, 1).reshape(-1, output_channel)
    print("delta_in_2d=\n", delta_flat)

    # Bias gradient: column sums, as an (output_channel, 1) column vector.
    bias_grad = delta_flat.sum(axis=0, keepdims=True).T / batch_size
    print("dB=\n", bias_grad)

    # Weight gradient, first in 2-D column form, then in the shape of wb.W.
    weight_grad = np.dot(x_col.T, delta_flat) / batch_size
    print("dW=\n", weight_grad)
    weight_grad = weight_grad.T.reshape(weight_shape)
    print("dW=\n", weight_grad)

    # Delta in column form, then folded back by col2img using x's shape.
    col_grad = np.dot(delta_flat, w_col.T)
    print("dcol=\n", col_grad)
    input_grad = col2img(col_grad, x.shape, fh, fw, stride, padding,
                         output_height, output_width)
    print("delta_out=\n", input_grad)