def build(self, input_shape):
    input_shape, mask_shape = input_shape
    # Validate the input shape before indexing into it.
    if len(input_shape) != 4:
        raise ValueError(
            f'Inputs should have rank 4. Received input shape: '
            f'{input_shape}')
    if input_shape[3] is None:
        raise ValueError('The channel dimension of the inputs '
                         'should be defined. Found `None`.')
    # Number of sparsity blocks covering each spatial dimension.
    self.block_count = [
        utils.divup(input_shape[1], self.block_stride[0]),
        utils.divup(input_shape[2], self.block_stride[1])
    ]
    input_dim = int(input_shape[3])
    kernel_shape = self.kernel_size + [input_dim, self.filters]
    self.kernel = self.add_weight(name='kernel',
                                  shape=kernel_shape,
                                  dtype=tf.float32,
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint,
                                  trainable=True)
    if self.use_bias:
        self.bias = self.add_weight(name='bias',
                                    shape=(self.filters,),
                                    dtype=tf.float32,
                                    initializer=self.bias_initializer,
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint,
                                    trainable=True)
    else:
        self.bias = None
    if self.use_var:
        # Preallocate a non-trainable output buffer for in-place scatter.
        output_shape = list(
            self.compute_output_shape([input_shape, mask_shape]))
        self.outputs = self.add_variable(name='outputs',
                                         shape=[self.batch_size] + output_shape[1:],
                                         dtype=tf.float32,
                                         initializer='zeros',
                                         trainable=False,
                                         use_resource=False)
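# Hedged usage sketch (the layer name and constructor arguments below are
# hypothetical, for illustration only): Keras passes build() the shapes of
# the layer's inputs, here an (input, mask) pair, hence the tuple unpack
# at the top of build().
#
#     layer = SparseConv2D(filters=64, kernel_size=[3, 3],
#                          block_stride=[14, 14])             # hypothetical
#     layer.build([(4, 256, 256, 64), (4, 19, 19, 1)])
#     # layer.block_count == [19, 19], since utils.divup(256, 14) == 19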
def build(self, mask_shape):
    # Mask-only variant: derive the block counts from the mask's spatial
    # dimensions rather than from the input tensor.
    self.block_count = [
        utils.divup(mask_shape[1], self.block_stride[0]),
        utils.divup(mask_shape[2], self.block_stride[1])
    ]
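# Worked example (assumed semantics): for a mask of shape [N, 256, 256, C]
# and block_stride == [14, 14], this build() sets
# block_count == [divup(256, 14), divup(256, 14)] == [19, 19],
# the same arithmetic as _compute_bcount() at the bottom of this file.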
def main():
    """For testing/understanding sbnet."""
    # tf.enable_eager_execution()

    # Specify input tensor dimensions and block-sparsity parameters
    batch = 4
    hw = 256
    channels = 64
    blockSize = [16, 16]
    blockStride = [14, 14]
    blockOffset = [0, 0]
    blockCount = [
        utils.divup(hw, blockStride[0]),
        utils.divup(hw, blockStride[1])
    ]

    # build kwargs to simplify op calls
    inBlockParams = {
        "dynamic_bsize": blockSize,
        "dynamic_boffset": blockOffset,
        "dynamic_bstride": blockStride
    }
    outBlockParams = {
        "dynamic_bsize": [blockSize[0] - 2, blockSize[1] - 2],
        "dynamic_boffset": blockOffset,
        "dynamic_bstride": blockStride
    }

    # create a random mask representing attention/a priori sparsity,
    # then threshold it to a specified percentile sparsity
    mask = np.random.randn(batch, blockCount[0], blockCount[1],
                           channels).astype(np.float32)
    threshold = np.percentile(mask, 90)
    sparseMask = np.greater(mask, threshold).astype(np.float32)

    # upsample the mask to full resolution (unused below; kept for reference)
    upsampledMask = sparseMask.repeat(blockStride[0],  # noqa
                                      axis=1).repeat(blockStride[1], axis=2)

    # create a random input tensor
    x = tf.constant(
        np.random.randn(batch, hw, hw, channels).astype(np.float32))

    # create a random weight tensor
    w = tf.constant(
        np.random.randn(3, 3, channels, channels).astype(np.float32))

    # reduce the mask to indices by using a fused pooling+indexing operation
    indices = sbnet.reduce_mask(mask, blockCount, tol=0.5, **inBlockParams)

    print("using gpu:",
          tf.test.is_gpu_available() and tf.test.is_built_with_cuda())
    print("bin_counts:", indices.bin_counts)
    print("bin_counts shape:", indices.bin_counts.shape)
    print("active_block_indices:", indices.active_block_indices)
    print("active_block_indices shape:", indices.active_block_indices.shape)

    # stack active overlapping tiles to batch dimension
    blockStack = sbnet.sparse_gather(x,
                                     indices.bin_counts,
                                     indices.active_block_indices,
                                     transpose=True,
                                     **inBlockParams)
    print("block_stack:", blockStack.shape)

    # perform dense convolution on a sparse stack of tiles
    convBlocks = tf.nn.conv2d(blockStack,
                              w,
                              strides=[1, 1, 1, 1],
                              padding='VALID',
                              data_format='NCHW')
    # convBlocks = keras.layers.Conv2D(channels, (3, 3), padding='valid',
    #                                  data_format='channels_first')(blockStack)

    # write/scatter the tiles back on top of the original tensor. Note that
    # the output tensor is reduced by 1 on each side due to 'VALID' convolution
    validX = x[:, 1:hw - 1, 1:hw - 1, :]
    y = sbnet.sparse_scatter(convBlocks,
                             indices.bin_counts,
                             indices.active_block_indices,
                             validX,
                             transpose=True,
                             add=False,
                             atomic=False,
                             **outBlockParams)

    if not tf.executing_eagerly():
        sess = tf.Session()
        y_output, = sess.run([y])
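# The block geometry in main() is not arbitrary: gathering 16x16 tiles on a
# stride of 14 gives adjacent tiles a 2-pixel overlap, exactly the shrinkage
# of a 3x3 'VALID' convolution, so each conv output tile is 14x14 and the
# scattered tiles cover the output with no gaps. The helper below is a
# hypothetical sketch of that arithmetic (not part of the sbnet API),
# assuming a square kernel and 'VALID' padding.
def _tile_geometry_sketch(block_size, block_stride, kernel_size=3):
    """Return (per-axis tile overlap, post-conv tile size) for 'VALID' conv.

    Example: _tile_geometry_sketch([16, 16], [14, 14]) == ([2, 2], [14, 14]),
    matching the dynamic_bsize of outBlockParams in main() above.
    """
    overlap = [bs - st for bs, st in zip(block_size, block_stride)]
    out_size = [bs - (kernel_size - 1) for bs in block_size]
    return overlap, out_size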
def _compute_bcount(size, bstride):
    # Number of blocks of stride `bstride` needed to cover a `size` extent.
    return [
        utils.divup(size[0], bstride[0]),
        utils.divup(size[1], bstride[1])
    ]
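# `utils.divup` is assumed to be ceiling integer division (how many blocks of
# stride `b` it takes to cover `a` pixels); a minimal sketch of the assumed
# behavior, for illustration only:
def _divup_sketch(a, b):
    """Hypothetical stand-in for utils.divup: ceiling integer division."""
    return (a + b - 1) // b
# Under that assumption, _compute_bcount([256, 256], [14, 14]) == [19, 19],
# since _divup_sketch(256, 14) == 19.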