Esempio n. 1
0
def sparse_conv2d(x, w, blk_indices, strides, padding):
    """
    Performs 2D convolution on a sparse feature map, given indices.
    Naive python implementation of sparse convolution using gather and scatter.

    Every block listed in `blk_indices` is gathered from the padded input,
    convolved with `w` using `VALID` padding, and the per-block results are
    scattered into a zero-initialised output of the dense output size.

    :param x:           [Tensor]  [N, H, W, C]. Input activation tensor, dtype float32.
    :param w:           [Tensor]  [I, J, C, K]. Convolution kernel, dtype float32.
    :param blk_indices: [Tensor]  [M, h, w, 3]. Block indices of rectangles, int32 or int64.
                                  They index the padded input directly.
    :param strides:     [list]    List of 4 int, convolution strides.
    :param padding:     [string]  `VALID` or `SAME`, padding method for sparse convolution.

    :return             [Tensor]  [N, H', W', K]. Convolution results. Output positions not
                                  written by any block are zero; if `blk_indices` is empty the
                                  whole result is zero.
    """
    blk_shape = tf.shape(blk_indices)
    blk_indices_ = tf.reshape(blk_indices, [-1, 3])
    ksize = tf.shape(w)

    # Calculate the block strides.
    bstrides = _calc_block_strides(blk_shape, ksize, strides)

    # Calculate the output size.
    x_shape = tf.shape(x)
    out_shape = calc_out_size_4d(x_shape, ksize, strides, padding)

    # Pad input; blk_indices are used as coordinates into this padded tensor below.
    x_ = _pad_input(x,
                    ksize,
                    strides,
                    padding,
                    bsize=[1, blk_shape[1], blk_shape[2], 1],
                    bstrides=bstrides)

    # Convolution when number of indices is larger than zero.
    def _conv_nonzero():
        # Gather patches: one row of C values per block element.
        p = tf.gather_nd(x_, blk_indices_)

        # Reshape patches into a batch of M blocks of size [h, w, C].
        p = tf.reshape(p, [blk_shape[0], blk_shape[1], blk_shape[2], -1])

        # Convolution on patches.
        q = tf.nn.conv2d(p, w, strides, 'VALID', use_cudnn_on_gpu=True)

        # Paste convolution results.
        q_shape = tf.shape(q)

        def _strides_gt_one():
            # Calculate output indices when strides > 1: keep every
            # strides[1]/strides[2]-th block index along the block's spatial
            # axes, then divide the (n, h, w) coordinates by (1, s_h, s_w).
            blk_indices_crop = tf.strided_slice(blk_indices, [0, 0, 0, 0], [
                blk_shape[0], q_shape[1] * strides[1], q_shape[2] * strides[2],
                3
            ], strides)
            # tf.stack of Python ints yields int32; cast so the division also
            # works for int64 block indices.
            divisor = tf.cast(tf.stack([1, strides[1], strides[2]]),
                              blk_indices_crop.dtype)
            return blk_indices_crop // divisor

        def _strides_one():
            # Calculate output indices when strides = 1.
            return blk_indices[:, :q_shape[1], :q_shape[2], :]

        # tf.cond constructs both branches, so _strides_gt_one must be valid
        # for every index dtype even when the strides are 1.
        strides_gt_one = tf.logical_or(tf.greater(strides[1], 1),
                                       tf.greater(strides[2], 1))
        blk_indices_crop = tf.cond(strides_gt_one, _strides_gt_one,
                                   _strides_one)
        # Note: tf.scatter_nd sums updates that land on the same index.
        y = tf.scatter_nd(blk_indices_crop, q, out_shape)
        return y

    return tf.cond(tf.equal(tf.size(blk_indices_), 0),
                   lambda: tf.zeros(out_shape, dtype=x.dtype), _conv_nonzero)
Esempio n. 2
0
def sparse_conv2d_matmul(x, w, blk_indices, strides, padding):
    """
    Performs 2D convolution using matrix multiplication on a sparse feature map.
    Naive python implementation of sparse convolution using gather and scatter.

    Each block must have exactly the kernel's spatial size, so every block
    produces one output vector. The block is flattened, multiplied by the
    flattened kernel and written to the block's first index. Both the block/kernel
    size match and unit strides are enforced with graph-level assertions.

    :param x:           [Tensor]  [N, H, W, C]. Input activation tensor, dtype float32.
    :param w:           [Tensor]  [I, J, C, K]. Convolution kernel, dtype float32.
    :param blk_indices: [Tensor]  [M, h, w, 3]. Block indices of rectangles; h == I and w == J.
    :param strides:     [list]    List of 4 int, convolution strides; strides[1:3] must be 1.
    :param padding:     [string]  `VALID` or `SAME`, padding method for sparse convolution.

    :return             [Tensor]  [N, H', W', K]. Convolution results; zero wherever no
                                  block writes (all zero if `blk_indices` is empty).
    """
    block_dims = tf.shape(blk_indices)
    flat_indices = tf.reshape(blk_indices, [-1, 3])
    kernel_dims = tf.shape(w)

    # Spacing between neighbouring blocks, used to pad the input consistently.
    block_strides = _calc_block_strides(block_dims, kernel_dims, strides)

    # Dense output size for the requested padding mode.
    out_shape = calc_out_size_4d(tf.shape(x), kernel_dims, strides, padding)

    padded_x = _pad_input(x,
                          kernel_dims,
                          strides,
                          padding,
                          bsize=[1, block_dims[1], block_dims[2], 1],
                          bstrides=block_strides)

    preconditions = [
        # Matrix-multiplication mode: block patch size must equal the kernel's spatial size.
        tf.assert_equal(
            tf.stack([block_dims[1], block_dims[2]]),
            tf.stack([kernel_dims[0], kernel_dims[1]]),
            message='Expect blk_indices.shape[1] == w.shape[0] and blk_indices.shape[2] == w.shape[1].'),
        # Strides > 1 are not implemented in this mode (yet).
        tf.assert_equal(
            tf.cast(tf.stack([strides[1], strides[2]]), tf.int64),
            tf.constant([1, 1], dtype=tf.int64),
            message='Strides > 1 not supported.'),
    ]

    def _empty_output():
        return tf.zeros(out_shape, dtype=x.dtype)

    def _gather_matmul_scatter():
        patch_len = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]
        # One row per block: its [I, J, C] window flattened in (i, j, c) order.
        rows = tf.reshape(tf.gather_nd(padded_x, flat_indices), [-1, patch_len])
        # Kernel flattened to [I * J * C, K] using the same (i, j, c) order.
        kernel_matrix = tf.reshape(w, [patch_len, -1])
        responses = tf.matmul(rows, kernel_matrix)
        # Each block's result goes to the index of its first (top-left) element;
        # presumably this maps to the kernel-centre output position once padding
        # is accounted for — confirm against _pad_input.
        anchors = blk_indices[:, 0, 0, :]
        return tf.scatter_nd(anchors, responses, out_shape)

    # Ops created in this scope (the predicate and the cond) wait on the assertions.
    with tf.control_dependencies(preconditions):
        has_no_blocks = tf.equal(tf.size(flat_indices), 0)
        return tf.cond(has_no_blocks, _empty_output, _gather_matmul_scatter)
Esempio n. 3
0
def calc_block_params(in_size, bsize, ksize, strides, padding, static=True):
    """
    Calculates block parameters for a single convolution layer.

    :param in_size:  [list]     List of 4 int. Size of the convolution input.
    :param bsize:    [list]     List of 4 int. Size of blocks, or downsample ratio.
    :param ksize:    [list]     List of 4 int. Sparse convolution kernel size.
    :param strides:  [list]     List of 4 int. Sparse convolution stride size.
                                Currently only supports when,
                                1) (bsize[1] - ksize[0]) % strides[1] == 0 and,
                                2) (bsize[2] - ksize[1]) % strides[2] == 0
    :param padding:  [string]   `VALID` or `SAME`, padding method for sparse convolution.
    :param static:   [bool]     If True, the block count is computed with
                                `calc_out_size_4d_np`; otherwise with `calc_out_size_4d`.

    :return          [BlockParams]
        bsize:      [h, w], i.e. bsize[1:3].
        bsize_out:  spatial part of calc_out_size_4d_np(bsize, ksize, strides, 'VALID'),
                    the size of a VALID convolution over one block.
        boffset:    [-pad_h0, -pad_w0], where the first block starts relative to the input.
        bcount:     number of block windows along H and W over the padded input.
        bstrides:   spatial block strides, i.e. _calc_block_strides(...)[1:3].

    Checks are plain `assert` statements (AssertionError, skipped under `python -O`):
    the stride constraints above, and that the last block starts inside the input.
    """
    # Block and kernel extents must differ by a whole number of strides (H, then W).
    for axis in (1, 2):
        assert ((bsize[axis] - ksize[axis - 1]) % strides[axis] == 0)

    bstrides = _calc_block_strides(bsize, ksize, strides)
    pad_h0, pad_h1, pad_w0, pad_w1 = calc_padding_4d(in_size, ksize, strides,
                                                     padding)
    h, w = in_size[1], in_size[2]

    # Extend the trailing (bottom/right) padding, intended so blocks divide the padded input.
    pad_h1 += (-h + bsize[1]) % bstrides[1]
    pad_w1 += (-w + bsize[2]) % bstrides[2]

    boffset = [-pad_h0, -pad_w0]
    x_pad_shape = [in_size[0], h + pad_h0 + pad_h1, w + pad_w0 + pad_w1, in_size[3]]

    # Number of block windows: a VALID sliding window of the block size over the padded input.
    out_size_fn = calc_out_size_4d_np if static else calc_out_size_4d
    out_shape = out_size_fn(x_pad_shape, [bsize[1], bsize[2], 1, 1], bstrides,
                            'VALID')
    bcount = [out_shape[1], out_shape[2]]

    # Per-block output size must use the full 4-element bsize, before it is trimmed below.
    bsize_out = calc_out_size_4d_np(bsize, ksize, strides, 'VALID')[1:3]
    bsize = bsize[1:3]
    bstrides = bstrides[1:3]

    if static:
        assert (pad_h0 == -boffset[0])
        assert (pad_w0 == -boffset[1])

    # NOTE(review): this loop also runs when static=False; if calc_out_size_4d returns
    # tf.Tensor values, a Python `assert` on them requires eager execution — confirm.
    for i, extent in enumerate((h, w)):
        # make sure last block is inside
        last_start = boffset[i] + bstrides[i] * (bcount[i] - 1)
        assert (last_start < extent), (
            'Making sure last block is inside boffset {} bstrides {} bcount {} size {}'
            .format(boffset[i], bstrides[i], bcount[i], extent))

    return BlockParams(bsize=bsize,
                       bsize_out=bsize_out,
                       boffset=boffset,
                       bcount=bcount,
                       bstrides=bstrides)