import tensorflow as tf


def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)):
    """
    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 4-D tensor of shape
            `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, in_height, in_width, in_channels]`.
    """
    batch_size, height, width, color_channels = image.get_shape().as_list()
    batch_size, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    # Treat the color channel dimension as the batch dimension since the same
    # transformation is applied to each color channel.
    # Treat the batch dimension as the channel dimension so that
    # depthwise_conv2d can apply a different transformation to each sample.
    kernels = tf.transpose(kernels, [1, 2, 0, 3])
    kernels = tf.reshape(kernels, [kernel_size[0], kernel_size[1], batch_size, num_transformed_images])
    # Swap the batch and channel dimensions.
    image_transposed = tf.transpose(image_padded, [3, 1, 2, 0])
    # Transform the image: each of the `batch_size` input channels is convolved
    # with its own set of `num_transformed_images` kernels.
    outputs = tf.nn.depthwise_conv2d(image_transposed, kernels, [1, 1, 1, 1], padding='VALID', rate=dilation_rate)
    # Transpose the dimensions back to where they belong.
    outputs = tf.reshape(outputs, [color_channels, height, width, batch_size, num_transformed_images])
    outputs = tf.transpose(outputs, [4, 3, 1, 2, 0])
    outputs = tf.unstack(outputs, axis=0)
    return outputs


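# `pad2d` is called above but not defined in this section. The following is a
# minimal sketch, assuming it pads the two spatial dimensions just enough that
# a subsequent VALID convolution with the given (possibly dilated) kernel size
# returns an output of the same height and width; the actual helper in the
# code base may have a different signature. Strides other than (1, 1) are not
# handled here, since both call sites use the default.
def pad2d(inputs, size, rate=(1, 1), padding='SAME', mode='CONSTANT'):
    size = list(size) if isinstance(size, (tuple, list)) else [size] * 2
    rate = list(rate) if isinstance(rate, (tuple, list)) else [rate] * 2
    if padding == 'VALID':
        return inputs
    # Effective kernel extent once dilation is taken into account.
    kernel_height = (size[0] - 1) * rate[0] + 1
    kernel_width = (size[1] - 1) * rate[1] + 1
    pad_h, pad_w = kernel_height - 1, kernel_width - 1
    paddings = [[0, 0],
                [pad_h // 2, pad_h - pad_h // 2],
                [pad_w // 2, pad_w - pad_w // 2],
                [0, 0]]
    return tf.pad(inputs, paddings, mode=mode)

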
def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)):
    """
    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, in_height, in_width, in_channels]`.
    """
    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
    batch_size, height, width, color_channels = image.get_shape().as_list()
    batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]

    # Flatten the spatial dimensions of the kernels.
    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width,
                                            kernel_size[0] * kernel_size[1], num_transformed_images])
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    # Combine the channel and batch dimensions into the first dimension.
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
                                                strides=[1] * 4, rates=[1] + dilation_rate + [1],
                                                padding='VALID')
    # Separate the channel and batch dimensions, and move the channel dimension.
    patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width,
                                                       kernel_size[0] * kernel_size[1]])
    patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4])
    # Reduce along the spatial dimensions of the kernel: a batched matmul of
    # `[..., color_channels, k*k]` patches with `[..., k*k, num_transformed_images]` kernels.
    outputs = tf.matmul(patches, kernels_reshaped)
    outputs = tf.unstack(outputs, axis=-1)
    return outputs


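# `flatten` is likewise used above without being defined here. A minimal
# sketch, assuming it merges the axes from `axis_start` through `axis_end`
# (inclusive) into a single axis; only statically known shapes are supported.
def flatten(inputs, axis_start, axis_end):
    shape = inputs.get_shape().as_list()
    merged = 1
    for dim in shape[axis_start:axis_end + 1]:
        merged *= dim
    return tf.reshape(inputs, shape[:axis_start] + [merged] + shape[axis_end + 1:])


# Hypothetical usage sketch (shapes follow the docstrings above): apply 4 CDNA
# kernels of size 5x5, and 4 per-pixel DNA kernels, to a batch of 8 RGB
# images. Each call returns a list of 4 tensors of shape [8, 64, 64, 3].
if __name__ == '__main__':
    image = tf.random_normal([8, 64, 64, 3])
    cdna_kernels = tf.random_normal([8, 5, 5, 4])
    dna_kernels = tf.random_normal([8, 64, 64, 5, 5, 4])
    cdna_outputs = apply_cdna_kernels(image, cdna_kernels)
    dna_outputs = apply_dna_kernels(image, dna_kernels)
    with tf.Session() as sess:
        cdna_arrays, dna_arrays = sess.run([cdna_outputs, dna_outputs])
        print([a.shape for a in cdna_arrays])  # 4 x (8, 64, 64, 3)
        print([a.shape for a in dna_arrays])   # 4 x (8, 64, 64, 3)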