def conv_backward_im2col(dout, cache):
    """
    Backward pass for a convolutional layer, computed with im2col/col2im.

    Inputs:
    - dout: Upstream gradient. It is summed over axes (0, 2, 3) for the bias and
      transposed with (1, 2, 3, 0) before flattening, so it is presumably laid
      out as (N, F, H_out, W_out) -- TODO confirm against the forward pass.
    - cache: Tuple (x, w, b, conv_param, x_cols) saved by the forward pass.
      conv_param must provide the keys "stride" and "pad". x_cols is presumably
      the im2col column matrix of x, whose columns line up with flattened dout.

    Returns a tuple of:
    - dx: Gradient w.r.t. x, rebuilt from column form by col2im_cython using
      x's four leading dimensions.
    - dw: Gradient w.r.t. w, reshaped to w.shape.
    - db: Gradient w.r.t. b, one value per filter (axis 1 of dout).
    """
    x, w, b, conv_param, x_cols = cache
    stride = conv_param["stride"]
    pad = conv_param["pad"]

    # Bias gradient: reduce every axis except the filter axis.
    db = np.sum(dout, axis=(0, 2, 3))

    n_filters, _, kernel_h, kernel_w = w.shape

    # Put the filter axis first and flatten the rest, so row f holds every
    # upstream gradient value produced by filter f.
    grad_rows = dout.transpose(1, 2, 3, 0).reshape(n_filters, -1)
    w_rows = w.reshape(n_filters, -1)

    # Weight gradient: correlate the gradient rows with the input columns.
    dw = grad_rows.dot(x_cols.T).reshape(w.shape)

    # Input gradient in column form, then scattered back by the col2im helper.
    grad_cols = w_rows.T.dot(grad_rows)
    dx = col2im_cython(
        grad_cols,
        *(x.shape[axis] for axis in range(4)),
        kernel_h,
        kernel_w,
        pad,
        stride,
    )

    return dx, dw, db
Example #2
0
def conv_backward_im2col(dout, cache):
  """
  Backward pass for a convolutional layer using im2col/col2im; this variant
  neither reads a bias from the cache nor returns a bias gradient.

  Inputs:
  - dout: Upstream gradient. It is transposed with (1, 2, 3, 0) before
    flattening, so it is presumably laid out as (N, F, H_out, W_out) --
    TODO confirm against the forward pass.
  - cache: Tuple (x, w, conv_param, x_cols) saved by the forward pass.
    conv_param must provide the keys 'stride' and 'pad'. x_cols is presumably
    the im2col column matrix of x.

  Returns a tuple (dx, dw):
  - dx: Gradient w.r.t. x, rebuilt by col2im_cython from column form.
  - dw: Gradient w.r.t. w, reshaped to w.shape.
  """
  x, w, conv_param, x_cols = cache
  stride = conv_param['stride']
  pad = conv_param['pad']

  n_filters, _, kernel_h, kernel_w = w.shape

  # One row per filter for both the upstream gradient and the weights.
  grad_rows = dout.transpose(1, 2, 3, 0).reshape(n_filters, -1)
  w_rows = w.reshape(n_filters, -1)

  dw = grad_rows.dot(x_cols.T).reshape(w.shape)

  # Column-form input gradient, scattered back to x's layout by col2im.
  grad_cols = w_rows.T.dot(grad_rows)
  dx = col2im_cython(grad_cols, *(x.shape[axis] for axis in range(4)),
                     kernel_h, kernel_w, pad, stride)

  return dx, dw