Beispiel #1
0
    def backward(self, dout):
        """Backpropagate through the convolution and apply one optimizer step.

        Uses the column matrix ``self.x_cols`` cached by the forward pass to
        compute the parameter gradients with matrix products, then scatters
        the column-space input gradient back to image layout with
        ``col2im_6d_cython``.

        Args:
            dout: Upstream gradient, expected shape ``(N, F, out_h, out_w)``.

        Returns:
            Gradient w.r.t. the cached input ``self.cache_input``, as produced
            by ``col2im_6d_cython`` (presumably shape ``(N, C, H, W)`` —
            confirm against the extension).

        Side effects:
            Stores ``self.dw`` (shape of ``self.weights``) and ``self.db``
            (shape ``(F,)``). Then replaces ``self.weights`` and ``self.bias``
            with the values returned by ``self.Weight_opt.update`` and
            ``self.Bias_opt.update``.
        """
        x = self.cache_input
        w = self.weights
        x_cols = self.x_cols
        stride = self.S_filter  # passed to col2im as the convolution stride
        pad = self.padding_input

        N, C, H, W = x.shape
        F, _, HH, WW = w.shape
        _, _, out_h, out_w = dout.shape

        # Bias gradient: sum over batch and spatial axes -> shape (F,).
        db = np.sum(dout, axis=(0, 2, 3))

        # (F, N*out_h*out_w) with columns ordered by (n, row, col); assumes
        # x_cols was built with the same column order in the forward pass.
        dout_reshaped = dout.transpose(1, 0, 2, 3).reshape(F, -1)
        dw = dout_reshaped.dot(x_cols.T).reshape(w.shape)

        # Gradient w.r.t. the column matrix: (C*HH*WW, N*out_h*out_w).
        dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
        # reshape() rather than assigning to .shape: identical view for this
        # contiguous result, without in-place shape mutation.
        dx_cols = dx_cols.reshape(C, HH, WW, N, out_h, out_w)
        dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)

        self.dw = dw
        self.db = db

        # dx was computed above from the pre-update weights, as it must be.
        self.weights = self.Weight_opt.update(self.weights, self.dw)
        self.bias = self.Bias_opt.update(self.bias, self.db)

        return dx
Beispiel #2
0
def conv_backward_strides(dout, cache):
    """Backward pass for a convolution whose forward pass cached ``x_cols``.

    Args:
        dout: Upstream gradient, expected shape ``(N, F, out_h, out_w)``.
        cache: Tuple ``(x, w, b, conv_param, x_cols)`` saved by the forward
            pass. ``conv_param`` must provide ``'stride'`` and ``'pad'``.
            ``x_cols`` must be the ``(C*HH*WW, N*out_h*out_w)`` column matrix
            of ``x``. ``b`` is not used here.

    Returns:
        Tuple ``(dx, dw, db)``:
        - ``dx`` as produced by ``col2im_6d_cython`` (presumably
          ``x.shape`` — confirm against the extension).
        - ``dw`` with the shape of ``w``.
        - ``db`` with shape ``(F,)``.
    """
    x, w, _, conv_param, x_cols = cache
    stride, pad = conv_param['stride'], conv_param['pad']

    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    _, _, out_h, out_w = dout.shape

    # Bias gradient: sum over batch and spatial axes -> shape (F,).
    db = np.sum(dout, axis=(0, 2, 3))

    # (F, N*out_h*out_w) with columns ordered by (n, row, col); assumes
    # x_cols uses the same column order.
    dout_reshaped = dout.transpose(1, 0, 2, 3).reshape(F, -1)
    dw = dout_reshaped.dot(x_cols.T).reshape(w.shape)

    # Gradient w.r.t. the column matrix: (C*HH*WW, N*out_h*out_w).
    dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
    # reshape() rather than assigning to .shape: identical view for this
    # contiguous result, without in-place shape mutation.
    dx_cols = dx_cols.reshape(C, HH, WW, N, out_h, out_w)
    dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)

    return dx, dw, db
Beispiel #3
0
def conv_backward_strides(dout, cache):
    """Backward pass for a convolution whose forward pass cached ``x_cols``.

    Args:
        dout: Upstream gradient, expected shape ``(N, F, out_h, out_w)``.
        cache: Tuple ``(x, w, b, conv_param, x_cols)`` saved by the forward
            pass. ``conv_param`` must provide ``'stride'`` and ``'pad'``.
            ``x_cols`` must be the ``(C*HH*WW, N*out_h*out_w)`` column matrix
            of ``x``. ``b`` is not used here.

    Returns:
        Tuple ``(dx, dw, db)``:
        - ``dx`` as produced by ``col2im_6d_cython`` (presumably
          ``x.shape`` — confirm against the extension).
        - ``dw`` with the shape of ``w``.
        - ``db`` with shape ``(F,)``.
    """
    x, w, _, conv_param, x_cols = cache
    stride, pad = conv_param['stride'], conv_param['pad']

    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    _, _, out_h, out_w = dout.shape

    # Bias gradient: sum over batch and spatial axes -> shape (F,).
    db = np.sum(dout, axis=(0, 2, 3))

    # (F, N*out_h*out_w) with columns ordered by (n, row, col); assumes
    # x_cols uses the same column order.
    dout_reshaped = dout.transpose(1, 0, 2, 3).reshape(F, -1)
    dw = dout_reshaped.dot(x_cols.T).reshape(w.shape)

    # Gradient w.r.t. the column matrix: (C*HH*WW, N*out_h*out_w).
    dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
    # reshape() rather than assigning to .shape: identical view for this
    # contiguous result, without in-place shape mutation.
    dx_cols = dx_cols.reshape(C, HH, WW, N, out_h, out_w)
    dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)

    return dx, dw, db
Beispiel #4
0
    def backward(self, dout):
        """Backpropagate through the convolution using the cached columns.

        Args:
            dout: Upstream gradient, expected shape ``(N, F, out_h, out_w)``.

        Returns:
            Gradient w.r.t. the cached input ``x``, as produced by
            ``col2im_6d_cython`` (presumably ``x.shape`` — confirm against
            the extension).

        Side effects:
            Stores ``self.dw`` (shape of ``w``) and ``self.db`` (shape
            ``(F,)``). Parameters are not updated here.

        Expects ``self.cache`` to be the tuple
        ``(x, w, b, stride, pad, x_cols)`` saved by the forward pass. The
        ``x_cols`` entry must be the ``(C*HH*WW, N*out_h*out_w)`` column
        matrix of ``x``, and ``b`` is not used.
        """
        x, w, _, stride, pad, x_cols = self.cache

        N, C, H, W = x.shape
        F, _, HH, WW = w.shape
        _, _, out_h, out_w = dout.shape

        # Bias gradient: sum over batch and spatial axes -> shape (F,).
        db = np.sum(dout, axis=(0, 2, 3))

        # (F, N*out_h*out_w) with columns ordered by (n, row, col); assumes
        # x_cols uses the same column order.
        dout_reshaped = dout.transpose(1, 0, 2, 3).reshape(F, -1)
        dw = dout_reshaped.dot(x_cols.T).reshape(w.shape)

        # Gradient w.r.t. the column matrix: (C*HH*WW, N*out_h*out_w).
        dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
        # reshape() rather than assigning to .shape: identical view for this
        # contiguous result, without in-place shape mutation.
        dx_cols = dx_cols.reshape(C, HH, WW, N, out_h, out_w)
        dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)

        self.dw = dw
        self.db = db

        return dx