Esempio n. 1
0
    def backward(self, dout):
        """Back-propagate through the convolution written as a matrix product.

        Stores the bias gradient in ``self.db`` and the filter gradient in
        ``self.dW``, then returns the gradient for the layer input.

        Args:
            dout: Upstream gradient. Assumed to be laid out as
                (N, FN, out_h, out_w) -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        num_filters, in_channels, filt_h, filt_w = self.W.shape

        # Put axis 1 last and flatten so each row has one entry per filter.
        grad_rows = dout.transpose(0, 2, 3, 1).reshape(-1, num_filters)

        # Bias gradient: sum over every row.
        self.db = grad_rows.sum(axis=0)

        # Transposing col makes the shapes line up for the product; the 2-D
        # result is then turned back into the 4-D filter layout.
        flat_dW = np.dot(self.col.T, grad_rows)
        self.dW = flat_dW.T.reshape(num_filters, in_channels, filt_h, filt_w)

        # Input gradient in column form, folded back by col2im.
        grad_cols = np.dot(grad_rows, self.col_W.T)
        return col2im(grad_cols, self.x.shape, filt_h, filt_w,
                      self.stride, self.pad)
Esempio n. 2
0
 def backward(self, dout):
     """Route the upstream gradient back to the max positions of each window.

     Args:
         dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
             -- TODO confirm against forward().

     Returns:
         The array produced by ``col2im`` for ``self.x.shape``.
     """
     grad_last = dout.transpose(0, 2, 3, 1)
     n, out_h, out_w, _ = grad_last.shape

     window = self.pool_h * self.pool_w

     # One row per upstream element, one column per slot in the window;
     # only the slot recorded in self.arg_max receives the gradient.
     scattered = np.zeros((grad_last.size, window))
     row_ids = np.arange(self.arg_max.size)
     scattered[row_ids, self.arg_max.ravel()] = grad_last.ravel()

     # Regroup so the first three axes collapse into the row dimension.
     dcol = scattered.reshape(n * out_h * out_w, -1)
     return col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                   self.stride, self.pad)
Esempio n. 3
0
    def backward(self, dout):
        """Compute ``self.db``, ``self.dW`` and the input gradient.

        Args:
            dout: Upstream gradient. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        filter_shape = self.W.shape
        n_filters, _, k_h, k_w = filter_shape

        # Axis 1 goes last; every row then holds one value per filter.
        dout_2d = dout.transpose(0, 2, 3, 1).reshape(-1, n_filters)

        self.db = dout_2d.sum(axis=0)

        # (col^T . dout) is transposed before being reshaped to the
        # filter's own 4-D shape.
        self.dW = self.col.T.dot(dout_2d).T.reshape(filter_shape)

        col_grad = dout_2d.dot(self.col_W.T)
        return col2im(col_grad, self.x.shape, k_h, k_w, self.stride, self.pad)
Esempio n. 4
0
    def backward(self, dout):
        """Backward pass of the convolution; fills ``self.db`` and ``self.dw``.

        Args:
            dout: Upstream gradient. Assumed (N, fn, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        num_filters, channels, filt_h, filt_w = self.w.shape

        # Flatten to two dimensions with one column per filter.
        grad_2d = dout.transpose(0, 2, 3, 1).reshape(-1, num_filters)

        self.db = grad_2d.sum(axis=0)

        weight_grad_2d = np.dot(self.col.T, grad_2d)
        self.dw = weight_grad_2d.T.reshape(num_filters, channels,
                                           filt_h, filt_w)

        grad_cols = np.dot(grad_2d, self.col_w.T)
        return col2im(grad_cols, self.x.shape, filt_h, filt_w,
                      self.stride, self.pad)
Esempio n. 5
0
    def backward(self, dout):
        """Propagate the gradient only to the elements selected by the max.

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        permuted = dout.transpose(0, 2, 3, 1)
        lead_rows = permuted.shape[0] * permuted.shape[1] * permuted.shape[2]
        slots = self.pool_h * self.pool_w

        # Zero everywhere except the recorded arg-max slot of each row.
        grad_table = np.zeros((permuted.size, slots))
        picked = self.arg_max.ravel()
        grad_table[np.arange(self.arg_max.size), picked] = permuted.ravel()

        # Collapse the first three axes of the permuted layout into rows.
        dcol = grad_table.reshape(lead_rows, -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                    self.stride, self.pad)
        return dx
Esempio n. 6
0
    def backward(self, dout):
        """Convolution backward pass via the im2col matrices.

        Side effects: assigns ``self.db`` and ``self.dW``.

        Args:
            dout: Upstream gradient. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        out_ch, in_ch, kh, kw = self.W.shape
        flat_grad = dout.transpose(0, 2, 3, 1).reshape(-1, out_ch)

        # Parameter gradients.
        bias_grad = flat_grad.sum(axis=0)
        weight_grad = np.dot(self.col.T, flat_grad).T
        self.db = bias_grad
        self.dW = weight_grad.reshape(out_ch, in_ch, kh, kw)

        # Gradient for the input, folded back from column form.
        return col2im(np.dot(flat_grad, self.col_W.T), self.x.shape,
                      kh, kw, self.stride, self.pad)
Esempio n. 7
0
    def backward(self, dout):
        """Backward pass: store ``db``/``dW`` and return the input gradient.

        Args:
            dout: Upstream gradient. Assumed (N, kernel_n, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        kernel_shape = self.W.shape
        n_kernels, _, k_height, k_width = kernel_shape

        grad_matrix = dout.transpose(0, 2, 3, 1).reshape(-1, n_kernels)

        self.db = grad_matrix.sum(axis=0)

        # The product has one column per kernel; transpose so kernels come
        # first, then restore the kernel tensor's own shape.
        per_kernel = np.dot(self.col.T, grad_matrix).T
        self.dW = per_kernel.reshape(kernel_shape)

        grad_columns = np.dot(grad_matrix, self.col_W.T)
        return col2im(grad_columns, self.x.shape, k_height, k_width,
                      self.stride, self.pad)
def test_im2col():
    """Check the local ``col2im`` against ``book_util.col2im``.

    Despite its name, this test exercises ``col2im``: both implementations
    receive the same (90, 75) integer matrix and the target shape
    (10, 3, 7, 7) with a 5x5 window, stride 1 and no padding.

    Raises:
        AssertionError: if the two results differ in shape or in any value.
    """
    col = np.arange(90 * 75).reshape(90, 75)
    target_shape = (10, 3, 7, 7)

    r1 = col2im(col, target_shape, 5, 5, stride=1, pad=0)
    r2 = book_util.col2im(col, target_shape, 5, 5, stride=1, pad=0)

    # array_equal also copes with a shape mismatch, where an element-wise
    # ``==`` followed by ``.all()`` would not give a clean True/False.
    matches = np.array_equal(r1, r2)
    print(f"test result of test_im2col: {matches}")

    # Printing alone can never fail a test run; make a mismatch fatal.
    assert matches, "col2im result differs from book_util.col2im"
Esempio n. 9
0
    def backward(self, dout):
        """Max-pooling backward pass.

        As with ReLU, the gradient passes through unchanged at the positions
        that won the max and is zero at every other position of a window.

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        grad = dout.transpose(0, 2, 3, 1)
        batch, rows, cols, _ = grad.shape
        per_window = self.pool_h * self.pool_w

        # Start from all zeros, then write dout into the arg-max slots.
        routed = np.zeros((grad.size, per_window))
        winners = self.arg_max.ravel()
        routed[np.arange(self.arg_max.size), winners] = grad.ravel()

        dcol = routed.reshape(batch * rows * cols, -1)
        return col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                      self.stride, self.pad)
    def backward(self, grad):
        """Back-propagate through pooling using the stored ``self.max_mask``.

        Only 4-D gradients are handled; a 2-D gradient coming from a
        fully-connected layer would have to be reshaped by the caller first.

        Args:
            grad: Upstream gradient with four axes, unpacked here as
                (N, out_H, out_W, filters).

        Returns:
            The array produced by ``util.col2im`` for ``self.x_shape``
            (per the original note, the shape of the forward input,
            [N, H, W, C] -- TODO confirm).
        """
        batch, out_rows, out_cols, n_filters = grad.shape
        win_h, win_w = self.kernel_size
        n_positions = batch * out_rows * out_cols

        # Column vector [N*out_H*out_W*filters, 1] so it broadcasts across
        # the columns of the mask.
        grad_column = grad.reshape(n_positions * n_filters, 1)
        masked = self.max_mask * grad_column

        # Regroup to [N*out_H*out_W, FH*FW*filters].
        unfolded = masked.reshape(n_positions, win_h * win_w * n_filters)

        return util.col2im(unfolded, self.x_shape, self.kernel_size,
                           self.strides, self.pad, self.out_shape)
Esempio n. 11
0
    def backward(self, dout):
        """Max-pooling backward pass.

        Forward shrank each window to one value, so on the way back every
        window is expanded again: the arg-max slot receives the incoming
        gradient and all remaining slots are filled with zero.

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        dout_perm = dout.transpose(0, 2, 3, 1)
        n_rows = dout_perm.shape[0] * dout_perm.shape[1] * dout_perm.shape[2]

        # Number of elements in one pooling window.
        window_elems = self.pool_h * self.pool_w

        expanded = np.zeros((dout_perm.size, window_elems))
        index_rows = np.arange(self.arg_max.size)
        expanded[index_rows, self.arg_max.ravel()] = dout_perm.ravel()

        dcol = expanded.reshape(n_rows, -1)
        return col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                      self.stride, self.pad)
Esempio n. 12
0
    def backward(self, dout):
        """Max-pooling backward pass.

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        # Axis 1 is moved to the end before flattening.
        grad_perm = dout.transpose(0, 2, 3, 1)
        d0, d1, d2, _ = grad_perm.shape

        k = self.pool_h * self.pool_w

        # Each flattened gradient value owns one row of k slots; the value
        # goes into the slot that held the maximum.
        slots = np.zeros((grad_perm.size, k))
        slots[np.arange(self.arg_max.size),
              self.arg_max.ravel()] = grad_perm.ravel()

        # Final 2-D layout: the first three axes become rows, and the last
        # axis together with the window slots becomes the columns.
        dcol = slots.reshape(d0 * d1 * d2, -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                    self.stride, self.pad)
        return dx
Esempio n. 13
0
    def backward(self, dout):
        """Convolution backward pass with optional un-flattening of ``dout``.

        When ``self.reshape_2dim`` is truthy, ``dout`` is first reshaped to
        (N, 1, H, W) using the stored input shape and then repeated FN times
        along axis 1.

        Args:
            dout: Upstream gradient.

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        n_filters, in_ch, k_h, k_w = self.W.shape

        grad = dout
        if self.reshape_2dim:
            batch, _, height, width = self.x.shape
            grad = np.repeat(grad.reshape(batch, 1, height, width),
                             n_filters, axis=1)

        grad_2d = grad.transpose(0, 2, 3, 1).reshape(-1, n_filters)

        self.db = grad_2d.sum(axis=0)
        self.dW = np.dot(self.col.T, grad_2d).T.reshape(n_filters, in_ch,
                                                        k_h, k_w)

        grad_cols = np.dot(grad_2d, self.col_W.T)
        return col2im(grad_cols, self.x.shape, k_h, k_w,
                      self.stride, self.pad)
Esempio n. 14
0
    def backward(self, dout):
        """Convolution backward pass.

        Args:
            dout: Upstream gradient. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        num_f, num_c, f_h, f_w = self.W.shape

        # Transpose first, then reshape. Per the original author's note this
        # mirrors forward(), which reshapes and then transposes with
        # (0, 3, 1, 2) -- TODO confirm.
        dout_mat = dout.transpose(0, 2, 3, 1).reshape(-1, num_f)

        # The bias gradient is simply the column-wise sum.
        self.db = dout_mat.sum(axis=0)

        # 2-D product with the stored column matrix, then back to the 4-D
        # filter layout.
        dW_mat = np.dot(self.col.T, dout_mat)
        self.dW = dW_mat.T.reshape(num_f, num_c, f_h, f_w)

        # col2im folds the 2-D column gradient back into an array shaped
        # like the stored input.
        dcol_mat = np.dot(dout_mat, self.col_W.T)
        return col2im(dcol_mat, self.x.shape, f_h, f_w,
                      self.stride, self.pad)
Esempio n. 15
0
    def backward(self, dout):
        """Convolution backward pass followed by a strided subsample of dx.

        Stores ``self.db`` and ``self.dW``. The input gradient returned by
        ``col2im`` is subsampled along its last two axes with step
        ``self.pad_stride`` (presumably undoing an expansion performed in
        forward -- TODO confirm).

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            ``dx[:, :, ::self.pad_stride, ::self.pad_stride]``.
        """
        # C must remain the filter's channel count. Unpacking dout.shape into
        # the same name would replace it with dout's channel count and break
        # the dW reshape whenever the two differ.
        FN, C, FH, FW = self.W.shape
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        # Same back-propagation as for an affine layer.
        self.db = np.sum(dout, axis=0)
        self.dW = np.dot(self.col.T, dout).transpose(1, 0).reshape(
            FN, C, FH, FW)

        dcol = np.dot(dout, self.col_W.T)
        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

        # Shrink the last two axes by keeping every pad_stride-th element.
        return dx[:, :, ::self.pad_stride, ::self.pad_stride]
    def backward(self, dout):
        """Convolution backward pass using ``util.col2im``.

        Args:
            dout: Upstream gradient. The original notes give
                (100, 30, 24, 24) as an example shape.

        Returns:
            The array produced by ``util.col2im`` for ``self.x.shape``.
        """
        n_filters, n_channels, k_h, k_w = self.W.shape

        # Move axis 1 last and cut the data into rows of n_filters values.
        grad_rows = dout.transpose(0, 2, 3, 1).reshape(-1, n_filters)

        self.db = grad_rows.sum(axis=0)

        # Product of the stored column matrix (presumably the input unrolled
        # per filter region in forward -- TODO confirm) with the gradient.
        dW_2d = np.dot(self.col.T, grad_rows)
        self.dW = dW_2d.T.reshape(n_filters, n_channels, k_h, k_w)

        # self.col_W is presumably the filter matrix prepared in forward
        # -- TODO confirm.
        grad_cols = np.dot(grad_rows, self.col_W.T)
        return util.col2im(grad_cols, self.x.shape, k_h, k_w,
                           self.stride, self.pad)
Esempio n. 17
0
    def backward(self, dout):
        """Back-propagation for max pooling.

        Args:
            dout (numpy.ndarray): Gradient arriving from the next layer;
                per the original documentation its shape is (N, C, OH, OW).

        Returns:
            numpy.ndarray: Gradient for the input; per the original
            documentation its shape is (N, C, H, W).
        """
        # (N, C, OH, OW) -> (N, OH, OW, C)
        grad = dout.transpose(0, 2, 3, 1)
        n, oh, ow, _ = grad.shape

        # Zero-initialised column gradient: (N * OH * OW * C, PH * PW).
        window = self.pool_h * self.pool_w
        grad_cols = np.zeros((grad.size, window))

        # Only the position that supplied the maximum receives the incoming
        # gradient (unchanged); all other positions keep their initial zero,
        # which is the same rule ReLU applies for x > 0 versus x <= 0.
        assert grad.size == self.arg_max.size, '順伝搬時のcol_xの行数と合わない'
        row_index = np.arange(self.arg_max.size)
        grad_cols[row_index, self.arg_max.ravel()] = grad.ravel()

        # (N * OH * OW * C, PH * PW) -> (N * OH * OW, C * PH * PW)
        grad_cols = grad_cols.reshape(n * oh * ow, -1)

        # Fold the columns back to the shape of the stored input.
        return col2im(grad_cols, self.x.shape, self.pool_h, self.pool_w,
                      self.stride, self.pad)
Esempio n. 18
0
    def backward(self, dout):
        """Convolution backward pass.

        Shape notes below follow the original annotations and are not
        verified here.

        Args:
            dout: Upstream gradient, annotated as (N, FN, out_h, out_w).

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``
            (annotated as N, C, H, W).
        """
        n_filt, n_chan, f_h, f_w = self.W.shape

        # (N, FN, out_h, out_w) -> (N, out_h, out_w, FN) -> (N*out_h*out_w, FN)
        grad_mat = dout.transpose(0, 2, 3, 1).reshape(-1, n_filt)

        self.db = grad_mat.sum(axis=0)

        # col.T: (C*FH*FW, N*out_h*out_w); product: (C*FH*FW, FN).
        # Transposed to (FN, C*FH*FW) and reshaped to (FN, C, FH, FW).
        product = np.dot(self.col.T, grad_mat)
        self.dW = product.T.reshape(n_filt, n_chan, f_h, f_w)

        # col_W.T: (FN, C*FH*FW); product: (N*out_h*out_w, C*FH*FW).
        col_grad = np.dot(grad_mat, self.col_W.T)
        return col2im(col_grad, self.x.shape, f_h, f_w,
                      self.stride, self.pad)
Esempio n. 19
0
    def backward(self, dout):
        """Max-pooling backward pass.

        Args:
            dout: Upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        g = dout.transpose(0, 2, 3, 1)
        a, b, c, _ = g.shape
        n_slots = self.pool_h * self.pool_w

        # Scatter each gradient value into the slot that produced the max.
        buf = np.zeros((g.size, n_slots))
        buf[np.arange(self.arg_max.size), self.arg_max.ravel()] = g.ravel()

        # Rows: product of the first three axes; columns: everything else.
        dcol = buf.reshape(a * b * c, -1)
        return col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                      self.stride, self.pad)

##################self################
# class Pooling:
#     def __init__(self,pool_h,pool_w,stride=1,pad=0):
#         self.pool_h = pool_h
#         self.pool_w = pool_w
#         self.stride = stride
#         self.pad = pad
#
#     def forward(self,x):
#         N,C,H,W = x.shape
#         out_h = int(1 + (H - self.pool_h) / self.stride)
#         out_w = int(1 + (W - self.pool_w) / self.stride)
#
#         #展開(1)
#         col = im2col(x,self.pool_h,self.pool_w,self.stride,self.pad)
#         col = col.reshape(-1,self.pool_h * self.pool_w)
#
#         #最大値(2)
#         out = np.max(col,axis=1)
#         #整形(3)
#         out = out.reshape(N, out_h, out_w, C).transpose(0,3,1,2)
#
#         return out
##################self##################
Esempio n. 20
0
    def backward(self, dout):
        """Back-propagation for max pooling.

        The gradient is passed only to the locations that held the maximum
        during the forward pass; those locations are read from
        ``self.arg_max``. The column-form gradient is also kept on
        ``self.dcol`` so the result can be inspected afterwards.

        Args:
            dout: Gradient from the output side.

        Returns:
            Gradient to hand to the input side, as produced by ``col2im``
            called with ``is_backward=True``.
        """
        # Move the channel axis (axis 1) to the fourth position.
        grad_last = dout.transpose(0, 2, 3, 1)

        # Elements per pooling window (window height x window width).
        window_elems = self.pool_h * self.pool_w

        # Gradient buffer: (total number of dout elements, window_elems).
        grad_cols = np.zeros((grad_last.size, window_elems))

        # Place each flattened dout value at its recorded max location.
        max_slots = self.arg_max.ravel()
        grad_cols[np.arange(self.arg_max.size), max_slots] = grad_last.ravel()

        # Fold back into an array shaped like the stored input.
        dx = col2im(grad_cols, self.x.shape, self.pool_h, self.pool_w,
                    self.stride, self.pad, is_backward=True)

        self.dcol = grad_cols  # kept for inspection
        return dx
Esempio n. 21
0
    def backward(self, dout):
        """Convolution backward pass.

        Args:
            dout: Upstream gradient. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``.
        """
        # Filter dimensions.
        n_filters, n_channels, filt_h, filt_w = self.W.shape

        # Bring the incoming gradient into matrix form. Per the original
        # note, forward() ends with reshape-then-transpose, so backward
        # starts with transpose-then-reshape.
        dy = dout.transpose(0, 2, 3, 1).reshape(-1, n_filters)

        # The bias is added to every data row in forward, so the gradients
        # of all rows are accumulated into each bias element:
        # db = sum over rows of dL/dY.
        self.db = dy.sum(axis=0)

        # dL/dW = X^T . dL/dY, restored to the 4-D filter layout.
        dW_flat = np.dot(self.col.T, dy)
        self.dW = dW_flat.transpose(1, 0).reshape(n_filters, n_channels,
                                                  filt_h, filt_w)

        # dL/dx = dL/dY . W^T, folded back by col2im.
        dx_cols = np.dot(dy, self.col_W.T)
        return col2im(dx_cols, self.x.shape, filt_h, filt_w,
                      self.stride, self.pad)
Esempio n. 22
0
    def backward(self, dout):
        """Back-propagation for the convolution, following the Affine layer.

        Stores ``self.db`` and ``self.dW`` and keeps the column-form input
        gradient on ``self.dcol`` so the result can be inspected afterwards.

        Args:
            dout: Gradient from the output side.

        Returns:
            Gradient to hand to the input side, as produced by ``col2im``
            called with ``is_backward=True``.
        """
        n_filters, n_channels, filt_h, filt_w = self.W.shape

        # Channel axis to the fourth position, then flatten to 2-D.
        grad_2d = dout.transpose(0, 2, 3, 1).reshape(-1, n_filters)

        # Bias: sum along the data direction.
        self.db = grad_2d.sum(axis=0)

        # Weights: product of the stored column matrix and the gradient,
        # restored to (filters, channels, filter height, filter width).
        flat_dW = np.dot(self.col.T, grad_2d)
        self.dW = flat_dW.T.reshape(n_filters, n_channels, filt_h, filt_w)

        # Input side: gradient multiplied by the filter weights.
        grad_cols = np.dot(grad_2d, self.col_W.T)

        # Fold back into an array shaped like the stored input.
        dx = col2im(grad_cols, self.x.shape, filt_h, filt_w,
                    self.stride, self.pad, is_backward=True)

        self.dcol = grad_cols  # kept for inspection
        return dx
    def backward(self, dout):
        """Max-pooling backward pass.

        Back-propagation retraces the forward pass line by line in reverse.
        The aim is to obtain dx for the layer upstream of this one; this
        layer stores no parameter gradients.

        :param dout: upstream gradient, 4-D. Assumed (N, C, out_h, out_w)
            -- TODO confirm against forward().
        :return: the array produced by ``col2im`` for ``self.x.shape``
        """
        # Put axis 1 last so the order matches the im2col layout
        # (presumably N * out_h * out_w * C -- TODO confirm).
        grad_perm = dout.transpose((0, 2, 3, 1))
        n, gh, gw, _ = grad_perm.shape

        # Number of elements covered by one pooling step.
        elems_per_pool = self.pool_h * self.pool_w

        # Same 2-D layout as the (-1, pool_size) matrix on which the max
        # was taken: one row per gradient value, pool_size columns.
        grad_max = np.zeros((grad_perm.size, elems_per_pool))

        # Use the indices recorded in forward to write each gradient value
        # at its row's max position. Everything else stays zero, so only
        # the max positions pass gradient on.
        row_idx = np.arange(self.arg_max.size)
        grad_max[row_idx, self.arg_max.ravel()] = grad_perm.ravel()

        # Same row layout as col (first three axes merged); col = im2col(x)
        # implies dx = col2im(dcol).
        dcol = grad_max.reshape(n * gh * gw, -1)
        return col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                      self.stride, self.pad)
Esempio n. 24
0
# Max-pooling walkthrough: forward pass, then backward pass, on names
# defined earlier in the script (x, N_, C_, H_, W_, pool_h, pool_w, stride,
# pad, im2col, col2im).
# Output height/width of the pooling, truncated to int.
out_h = int(1 + (H_ - pool_h) / stride)
out_w = int(1 + (W_ - pool_w) / stride)

col_old = im2col(x, pool_h, pool_w, stride, pad) # try printing this
col = col_old.reshape(-1, pool_h*pool_w) # compare with the result above

arg_max = np.argmax(col, axis=1) # remember where each row's max is (np.argmax)
out = np.max(col, axis=1) # compute the output
out_final = out.reshape(N_, out_h, out_w, C_).transpose(0, 3, 1, 2)
print("out_final :\n", out_final)

######################################################################
#%% backward
# Use an all-ones upstream gradient with the same shape as the output.
dout = np.ones_like(out_final)
dout = dout.transpose(0, 2, 3, 1)
        
pool_size = pool_h * pool_w

# Buffer for the result (widened by the pooling size).
dmax = np.zeros((dout.size, pool_size))

# Fill in the values at the remembered positions.
dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()

# Restore the shape: dout's shape plus a trailing pool_size axis.
dmax = dmax.reshape(dout.shape + (pool_size,)) 

# Merge the first three axes into rows before folding back with col2im.
dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
dx = col2im(dcol, x.shape, pool_h, pool_w, stride, pad)
print("dx : \n", dx)
Esempio n. 25
0
    def backward(self, dout):
        """Convolution backward pass with optional gradient quantization.

        When ``self.enable_bp_gradient_quantization`` is falsy, the plain
        floating-point backward pass runs; ``self.qat_scheme`` makes no
        difference on that path.

        When it is truthy, the operands of both matrix products go through
        ``self.quantizer`` and the behaviour depends on ``self.qat_scheme``:

        * ``Dequantization``: quantize, multiply, dequantize.
        * ``ErrorCompensation``: as above, plus a ``bp_compensation`` term
          added to each product before dequantizing.
        * ``LossAwareCompensation``: as ``Dequantization``, but when
          ``self.enable_gradient_clipping`` is truthy the operands first
          pass through ``self.parametrized_range_clipping``.
        * anything else: a message is printed and ``None`` is returned
          without touching ``self.db`` / ``self.dW``.

        The ``*_int8`` names follow the original code; the actual dtype is
        whatever the quantizer returns.

        Args:
            dout: Upstream gradient. Assumed (N, FN, out_h, out_w)
                -- TODO confirm against forward().

        Returns:
            The array produced by ``col2im`` for ``self.x.shape``, or
            ``None`` for an undefined scheme.
        """
        if not self.enable_bp_gradient_quantization:
            FN, C, FH, FW = self.W.shape
            dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

            self.db = np.sum(dout, axis=0)
            self.dW = np.dot(self.col.T, dout).transpose(1, 0).reshape(
                FN, C, FH, FW)

            dcol = np.dot(dout, self.col_W.T)
            return col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

        scheme = self.qat_scheme
        use_compensation = scheme == KMQATScheme.ErrorCompensation
        loss_aware = scheme == KMQATScheme.LossAwareCompensation
        if not (scheme == KMQATScheme.Dequantization or use_compensation
                or loss_aware):
            print("====== Exception: Undefined QAT Scheme in {} {}".format(
                self.layer_id,
                sys._getframe().f_code.co_name))
            return None

        quantizer = self.quantizer
        FN, C, FH, FW = self.W.shape
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        # Bias gradient: quantize/dequantize round trip of the plain sum.
        db = np.sum(dout, axis=0)
        self.db = quantizer.fake_dequantize(quantizer.fake_quantize(db))

        # Clipping applies only to the loss-aware scheme.
        clip = loss_aware and self.enable_gradient_clipping

        # Weight gradient. Call order (clip col, clip dout, quantize col,
        # quantize dout) is kept in case the quantizer is stateful.
        if clip:
            col_src = self.parametrized_range_clipping(self.col)
            dout_src = self.parametrized_range_clipping(dout)
        else:
            col_src, dout_src = self.col, dout
        col_int8 = quantizer.bp_col_quantize(col_src)
        dout_int8 = quantizer.bp_dout_quantize(dout_src)

        dW_int8 = np.dot(col_int8.T, dout_int8)
        if use_compensation:
            dW_int8 = dW_int8 + quantizer.bp_compensation(
                self.col, col_int8, dout_int8)
        dW_q = quantizer.bp_dW_dequantize(dW_int8)
        self.dW = dW_q.transpose(1, 0).reshape(FN, C, FH, FW)

        # Input gradient.
        if clip:
            col_W_src = self.parametrized_range_clipping(self.col_W)
        else:
            col_W_src = self.col_W
        dcol_W_int8 = quantizer.bp_col_W_quantize(col_W_src)

        dcol_int8 = np.dot(dout_int8, dcol_W_int8.T)
        if use_compensation:
            dcol_int8 = dcol_int8 + quantizer.bp_compensation(
                self.col_W, dcol_W_int8, dout_int8, com_dx=True)
        dcol_q = quantizer.bp_dcol_dequantize(dcol_int8)

        return col2im(dcol_q, self.x.shape, FH, FW, self.stride, self.pad)
    def test_col2im_transforms(self):
        """col2im must return an array with the requested target shape.

        Feeds a random (3125, 16) column matrix with a 4x4 window, stride 1
        and no padding, and checks only the shape of the result.
        """
        expected_shape = (5, 1, 28, 28)
        columns = np.random.randn(3125, 16)

        image = col2im(columns, expected_shape, 4, 4, stride=1, pad=0)

        self.assertSequenceEqual(expected_shape, image.shape)