def test_prepare_conv_args(self):
  [filter_shape, rank, strides, padding, dilations
  ] = convolution_util.prepare_conv_args(
      (3, 3), rank=2, strides=2, padding='same', dilations=(1, 1))
  for arg in [filter_shape, strides, dilations]:
    self.assertLen(arg, rank)
  self.assertEqual(padding, 'SAME')
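
# The assertions above pin down the normalization contract: scalar `strides`
# and `dilations` are broadcast to length-`rank` tuples, and `padding` is
# uppercased. Below is a minimal sketch of that contract, assuming no extra
# validation beyond what the test exercises. `_prepare_conv_args_sketch` is
# hypothetical and is not the real `convolution_util.prepare_conv_args`.
def _prepare_conv_args_sketch(filter_shape, rank, strides, padding, dilations):
  as_tuple = lambda v: (v,) * rank if isinstance(v, int) else tuple(v)
  return (as_tuple(filter_shape), rank, as_tuple(strides),
          padding.upper(), as_tuple(dilations))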
def _make_convolution_fn(rank, strides, padding, dilations):
  """Helper to create tf convolution op."""
  [
      _,
      rank,
      strides,
      padding,
      dilations,
  ] = convolution_util.prepare_conv_args(1, rank, strides, padding, dilations)

  def op(x, kernel):
    dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')
    return tf.nn.convolution(
        x, kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)

  return lambda x, kernel: nn_util_lib.batchify_op(op, rank + 1, x, kernel)
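
# Usage sketch for `_make_convolution_fn`. Shapes are chosen for illustration
# only, and this assumes the module's `import tensorflow as tf` plus the
# (plausible) behavior that `nn_util_lib.batchify_op` flattens any leading
# dimensions beyond the op's `rank + 1` event dimensions before dispatching
# to `op`, then restores them afterwards.
conv_fn = _make_convolution_fn(rank=2, strides=1, padding='SAME', dilations=1)
x = tf.random.normal([4, 28, 28, 3])     # [n, h, w, c]
kernel = tf.random.normal([3, 3, 3, 8])  # [fh, fw, c, c']
y = conv_fn(x, kernel)                   # [4, 28, 28, 8] with SAME padding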
def _convolution_batch_nhwbc(
    x, kernel, rank, strides, padding, dilations, name):
  """Specialization of batch conv to NHWBC data format."""
  with tf.name_scope(name or 'conv2d_nhwbc'):
    # Prepare arguments.
    [
        _,  # filter shape
        rank,
        _,  # strides
        padding,
        dilations,
    ] = convolution_util.prepare_conv_args(1, rank, strides, padding, dilations)
    strides = prepare_strides(strides, rank + 2, arg_name='strides')

    dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

    # Step 1: Transpose and double flatten kernel.
    # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
    kernel_shape = ps.shape(kernel)
    kernel_batch_shape, kernel_event_shape = ps.split(
        kernel_shape, num_or_size_splits=[-1, rank + 2])
    kernel_batch_size = ps.reduce_prod(kernel_batch_shape)
    kernel_ndims = ps.rank(kernel)
    kernel_batch_ndims = kernel_ndims - rank - 2
    perm = ps.concat([
        ps.range(kernel_batch_ndims, kernel_batch_ndims + rank),
        ps.range(0, kernel_batch_ndims),
        ps.range(kernel_batch_ndims + rank, kernel_ndims),
    ], axis=0)  # Eg, [1, 2, 0, 3, 4]
    kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
    kernel = tf.reshape(
        kernel,
        shape=ps.concat([
            kernel_event_shape[:rank],
            [kernel_batch_size * kernel_event_shape[-2],
             kernel_event_shape[-1]],
        ], axis=0))  # F + [bc, c']

    # Step 2: Double flatten x.
    # x.shape = N + D + B + [c]
    x_shape = ps.shape(x)
    [
        x_sample_shape,
        x_rank_shape,
        x_batch_shape,
        x_channel_shape,
    ] = ps.split(
        x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
    x = tf.reshape(
        x,  # N + D + B + [c]
        shape=ps.concat([
            [ps.reduce_prod(x_sample_shape)],
            x_rank_shape,
            [ps.reduce_prod(x_batch_shape) * ps.reduce_prod(x_channel_shape)],
        ], axis=0))  # [n] + D + [bc]

    # Step 3: Apply convolution.
    y = tf.nn.depthwise_conv2d(
        x, kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)
    # SAME: y.shape = [n, h, w, bcc']
    # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

    # Step 4: Reshape/reduce for output.
    y_shape = ps.shape(y)
    y = tf.reshape(
        y,
        shape=ps.concat([
            x_sample_shape,
            y_shape[1:-1],
            kernel_batch_shape,
            kernel_event_shape[-2:],
        ], axis=0))  # N + D' + B + [c, c']
    y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

    return y
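
# Usage sketch for `_convolution_batch_nhwbc`. Per the shape comments in the
# function body, `x` carries the kernel batch axis between its spatial and
# channel axes (NHWBC), and each of the `b` kernels is applied to its
# matching batch slice of `x`. The example shapes below are illustrative
# only; the expected output shape follows from the Step 4 comments above.
x = tf.random.normal([7, 14, 14, 2, 3])     # N + D + B + [c] = [n, h, w, b, c]
kernel = tf.random.normal([2, 3, 3, 3, 5])  # B + F + [c, c'] = [b, fh, fw, c, c']
y = _convolution_batch_nhwbc(
    x, kernel, rank=2, strides=1, padding='SAME', dilations=1,
    name='conv2d_nhwbc_example')
# y.shape == [7, 14, 14, 2, 5]  # N + D' + B + [c']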