def sparse_conv2d(x, w, blk_indices, strides, padding):
    """
    Performs 2D convolution on a sparse feature map, restricted to the given blocks.

    Naive Python implementation: gather the input blocks, run a dense VALID
    convolution on each gathered block, then scatter the block outputs back
    into a dense, zero-initialised output tensor.

    :param x:           [Tensor] [N, H, W, C]. Input activation tensor, dtype float32.
    :param w:           [Tensor] [I, J, C, K]. Convolution kernel, dtype float32.
    :param blk_indices: [Tensor] [M, h, w, 3]. Block indices of rectangles; each
                        trailing triple indexes one element of the padded input
                        (it is used directly with `tf.gather_nd` on it).
    :param strides:     [list]   List of 4 int, convolution strides.
    :param padding:     [string] `VALID` or `SAME`, padding method for sparse convolution.

    :return             [Tensor] [N, H', W', K]. Convolution results; positions not
                        covered by any block are zero.
    """
    idx_shape = tf.shape(blk_indices)
    flat_indices = tf.reshape(blk_indices, [-1, 3])
    kernel_shape = tf.shape(w)

    # Step between consecutive blocks in the padded input.
    block_strides = _calc_block_strides(idx_shape, kernel_shape, strides)

    # Size of the dense output the block results are scattered into.
    result_shape = calc_out_size_4d(tf.shape(x), kernel_shape, strides, padding)

    # Pad the input so that the block indices address the padded tensor.
    padded = _pad_input(
        x,
        kernel_shape,
        strides,
        padding,
        bsize=[1, idx_shape[1], idx_shape[2], 1],
        bstrides=block_strides)

    def _run_blocks():
        # Gather every block element, then regroup into [M, h, w, C] patches.
        patches = tf.gather_nd(padded, flat_indices)
        patches = tf.reshape(patches, [idx_shape[0], idx_shape[1], idx_shape[2], -1])

        # Dense VALID convolution over each patch independently.
        conv_out = tf.nn.conv2d(patches, w, strides, 'VALID', use_cudnn_on_gpu=True)
        conv_shape = tf.shape(conv_out)
        out_h = conv_shape[1]
        out_w = conv_shape[2]

        def _downsampled_indices():
            # Strided case: keep every stride-th block index inside the cropped
            # window, then divide the spatial components by the stride to obtain
            # output coordinates (the batch component is divided by 1).
            kept = tf.strided_slice(
                blk_indices, [0, 0, 0, 0],
                [idx_shape[0], out_h * strides[1], out_w * strides[2], 3],
                strides)
            return kept // tf.stack([1, strides[1], strides[2]])

        def _cropped_indices():
            # Unit-stride case: the leading out_h x out_w indices of each block
            # are used directly as output coordinates.
            return blk_indices[:, :out_h, :out_w, :]

        is_strided = tf.logical_or(tf.greater(strides[1], 1), tf.greater(strides[2], 1))
        out_indices = tf.cond(is_strided, _downsampled_indices, _cropped_indices)

        # Paste the per-block results into the dense output.
        return tf.scatter_nd(out_indices, conv_out, result_shape)

    # With no blocks there is nothing to gather; emit an all-zero output instead.
    no_blocks = tf.equal(tf.size(flat_indices), 0)
    return tf.cond(no_blocks, lambda: tf.zeros(result_shape, dtype=x.dtype), _run_blocks)
def sparse_conv2d_matmul(x, w, blk_indices, strides, padding):
    """
    Performs 2D convolution on a sparse feature map using matrix multiplication.

    Each block must be exactly one kernel window (block spatial size == kernel
    spatial size) and strides must be 1; both conditions are checked with
    `tf.assert_equal` ops that the result depends on. Each block is flattened
    into a row and multiplied by the flattened kernel, giving one output pixel
    per block, which is scattered into a dense zero-initialised output.

    :param x:           [Tensor] [N, H, W, C]. Input activation tensor, dtype float32.
    :param w:           [Tensor] [I, J, C, K]. Convolution kernel, dtype float32.
    :param blk_indices: [Tensor] [M, h, w, 3]. Block indices of rectangles
                        (h == I and w == J are required).
    :param strides:     [list]   List of 4 int, convolution strides (spatial strides must be 1).
    :param padding:     [string] `VALID` or `SAME`, padding method for sparse convolution.

    :return             [Tensor] [N, H', W', K]. Convolution results; positions not
                        covered by any block are zero.
    """
    flat_indices = tf.reshape(blk_indices, [-1, 3])
    idx_shape = tf.shape(blk_indices)
    kernel_shape = tf.shape(w)

    # Step between consecutive blocks in the padded input.
    block_strides = _calc_block_strides(idx_shape, kernel_shape, strides)

    # Size of the dense output the block results are scattered into.
    result_shape = calc_out_size_4d(tf.shape(x), kernel_shape, strides, padding)

    # Pad the input so that the block indices address the padded tensor.
    padded = _pad_input(
        x,
        kernel_shape,
        strides,
        padding,
        bsize=[1, idx_shape[1], idx_shape[2], 1],
        bstrides=block_strides)

    # The matmul formulation only works when each block is one kernel window.
    check_block_size = tf.assert_equal(
        tf.stack([idx_shape[1], idx_shape[2]]),
        tf.stack([kernel_shape[0], kernel_shape[1]]),
        message='Expect blk_indices.shape[1] == w.shape[0] and blk_indices.shape[2] == w.shape[1].')

    # TODO: strides > 1 are not implemented for the matmul path yet.
    check_unit_strides = tf.assert_equal(
        tf.cast(tf.stack([strides[1], strides[2]]), tf.int64),
        tf.constant([1, 1], dtype=tf.int64),
        message='Strides > 1 not supported.')

    def _run_blocks():
        # Number of input values in one kernel window: I * J * C.
        window_len = kernel_shape[0] * kernel_shape[1] * kernel_shape[2]

        # One row per block: [M, I * J * C].
        rows = tf.reshape(tf.gather_nd(padded, flat_indices), [-1, window_len])

        # Flattened kernel: [I * J * C, K].
        kernel_mat = tf.reshape(w, [window_len, -1])
        products = tf.matmul(rows, kernel_mat)

        # Output position of each block is taken from its first (top-left) index.
        anchors = blk_indices[:, 0, 0, :]

        # Paste the per-block results into the dense output.
        return tf.scatter_nd(anchors, products, result_shape)

    with tf.control_dependencies([check_block_size, check_unit_strides]):
        # With no blocks there is nothing to gather; emit an all-zero output instead.
        no_blocks = tf.equal(tf.size(flat_indices), 0)
        return tf.cond(no_blocks, lambda: tf.zeros(result_shape, dtype=x.dtype), _run_blocks)
def calc_block_params(in_size, bsize, ksize, strides, padding, static=True):
    """
    Calculates block parameters for a single convolution layer.

    :param in_size:  [list]   List of 4 int. Size of the convolution input.
    :param bsize:    [list]   List of 4 int. Size of blocks, or downsample ratio.
    :param ksize:    [list]   List of 4 int. Sparse convolution kernel size.
    :param strides:  [list]   List of 4 int. Sparse convolution stride size.
                              Currently only supports when,
                                1) (bsize[1] - ksize[0]) % strides[1] == 0 and,
                                2) (bsize[2] - ksize[1]) % strides[2] == 0
    :param padding:  [string] `VALID` or `SAME`, padding method for sparse convolution.
    :param static:   [bool]   If True, the block count is computed with
                              `calc_out_size_4d_np` and the "last block is inside
                              the input" sanity check is run; if False,
                              `calc_out_size_4d` is used and no check is run.

    :return          [BlockParams] with fields (spatial H, W components only):
        bsize:     [list] 2 int. Block size.
        bsize_out: [list] 2 int. Size of a VALID convolution (ksize, strides)
                   applied to one block.
        boffset:   [list] 2 int. Position of the first block, i.e. the negated
                   top/left padding.
        bcount:    [list] 2 values. Number of blocks along H and W (presumably
                   tensors rather than ints when static=False — confirm against
                   calc_out_size_4d).
        bstrides:  [list] 2 int. Step between consecutive blocks.

    :raises AssertionError: if the stride constraints above are violated, or
                            (static=True only) if the last block along an axis
                            would start outside the input.
    """
    assert (bsize[1] - ksize[0]) % strides[1] == 0, (
        'Expect (bsize[1] - ksize[0]) % strides[1] == 0')
    assert (bsize[2] - ksize[1]) % strides[2] == 0, (
        'Expect (bsize[2] - ksize[1]) % strides[2] == 0')

    bstrides = _calc_block_strides(bsize, ksize, strides)
    pad_h0, pad_h1, pad_w0, pad_w1 = calc_padding_4d(in_size, ksize, strides, padding)
    h = in_size[1]
    w = in_size[2]

    # Extra bottom/right padding intended to make the padded input divide into
    # blocks. NOTE(review): this term does not account for pad_h0/pad_h1 already
    # added; presumably it mirrors the padding done in _pad_input — confirm
    # before changing either side.
    pad_h1 += (-h + bsize[1]) % bstrides[1]
    pad_w1 += (-w + bsize[2]) % bstrides[2]

    # The first block starts at minus the top/left padding.
    boffset = [-pad_h0, -pad_w0]
    x_pad_shape = [in_size[0], h + pad_h0 + pad_h1, w + pad_w0 + pad_w1, in_size[3]]

    # Number of block positions: a VALID sliding window of block size, stepping
    # by bstrides over the padded input.
    block_window = [bsize[1], bsize[2], 1, 1]
    if static:
        out_shape = calc_out_size_4d_np(x_pad_shape, block_window, bstrides, 'VALID')
    else:
        out_shape = calc_out_size_4d(x_pad_shape, block_window, bstrides, 'VALID')
    bcount = [out_shape[1], out_shape[2]]

    bsize_out = calc_out_size_4d_np(bsize, ksize, strides, 'VALID')

    # Keep only the spatial (H, W) components.
    bsize_hw = bsize[1:3]
    bstrides_hw = bstrides[1:3]
    bsize_out_hw = bsize_out[1:3]

    if static:
        # Make sure the last block along each axis starts inside the input.
        for axis, siz in enumerate((h, w)):
            err_msg = 'Making sure last block is inside boffset {} bstrides {} bcount {} size {}'.format(
                boffset[axis], bstrides_hw[axis], bcount[axis], siz)
            last_start = boffset[axis] + bstrides_hw[axis] * (bcount[axis] - 1)
            assert last_start < siz, err_msg

    return BlockParams(
        bsize=bsize_hw,
        bsize_out=bsize_out_hw,
        boffset=boffset,
        bcount=bcount,
        bstrides=bstrides_hw)