def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)):
    """Convolve every sample of `image` with its own set of kernels.

    Each sample in the batch has `num_transformed_images` kernels; every
    kernel is applied identically to all color channels of that sample.

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 4-D tensor of shape
            `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`.
        dilation_rate: Forwarded as `rate` to both `pad2d` and
            `tf.nn.depthwise_conv2d`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, in_height, in_width, in_channels]`.
    """
    _, height, width, color_channels = image.get_shape().as_list()
    (batch_size, kernel_rows, kernel_cols,
     num_transformed_images) = kernels.get_shape().as_list()

    # Pad up front so that the 'VALID' convolution below yields an output of
    # the original spatial size (assuming `pad2d` implements 'SAME'-style
    # padding for the given kernel size and rate).
    padded = pad2d(image, [kernel_rows, kernel_cols], rate=dilation_rate,
                   padding='SAME', mode='SYMMETRIC')

    # depthwise_conv2d applies a separate filter stack to each input channel.
    # Putting the batch axis in the channel slot gives every sample its own
    # kernels, while putting the color axis in the batch slot reuses those
    # kernels across all color channels.
    # Filters: [kernel_rows, kernel_cols, batch, num_transformed_images].
    filters = tf.reshape(
        tf.transpose(kernels, [1, 2, 0, 3]),
        [kernel_rows, kernel_cols, batch_size, num_transformed_images])
    # Input: [in_channels, padded_height, padded_width, batch].
    channels_as_batch = tf.transpose(padded, [3, 1, 2, 0])

    convolved = tf.nn.depthwise_conv2d(
        channels_as_batch, filters, [1, 1, 1, 1],
        padding='VALID', rate=dilation_rate)

    # The last axis of the result interleaves (batch, transformation); split
    # it and move the axes back to
    # [num_transformed_images, batch, height, width, in_channels].
    convolved = tf.reshape(
        convolved,
        [color_channels, height, width, batch_size, num_transformed_images])
    convolved = tf.transpose(convolved, [4, 3, 1, 2, 0])
    return tf.unstack(convolved, axis=0)
def apply_dna_kernels_non_dilated(image, kernels):
    """Apply a separate kernel at every pixel of `image` (no dilation).

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1],
            num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, in_height, in_width, in_channels]`.
    """
    batch_size, height, width, color_channels = image.get_shape().as_list()
    # The kernel's two spatial extents must be unpacked separately: they are
    # needed both as a product (flattened patch length) and as a list
    # (`ksizes` / `pad2d`). Binding them to a single name made `kernel_size`
    # an int and broke every use below.
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]
    patch_len = kernel_height * kernel_width

    # Flatten the spatial dimensions of the kernel.
    kernels_reshaped = tf.reshape(
        kernels,
        [batch_size, height, width, patch_len, num_transformed_images])

    # Pad so that the 'VALID' patch extraction below yields one patch per
    # original pixel (assuming `pad2d` implements 'SAME'-style padding).
    image_padded = pad2d(image, kernel_size, padding='SAME', mode='SYMMETRIC')

    # Combine channel and batch dimensions into the first dimension so that
    # patches are extracted from each color channel independently.
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(
        image_reshaped,
        ksizes=[1] + kernel_size + [1],
        strides=[1] * 4,
        rates=[1] * 4,
        padding='VALID')

    # Separate channel and batch dimensions.
    patches = tf.reshape(
        patches_reshaped,
        [color_channels, batch_size, height, width, patch_len])

    # Weight each patch by its pixel's kernel and reduce along the (flattened)
    # spatial dimensions of the kernel.
    # [C, B, H, W, K, 1] * [1, B, H, W, K, N] -> sum over K -> [C, B, H, W, N]
    outputs = tf.reduce_sum(
        patches[..., None] * kernels_reshaped[None, ...], axis=-2)

    # Swap channel and transformation dimensions: [N, B, H, W, C].
    outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
    return tf.unstack(outputs, axis=0)
def apply_dna_kernels_dilated(image, kernels, dilation_rate=(1, 1)):
    """Apply a separate, dilated kernel at every pixel of `image`.

    The dilated operation is carried out as
    `dilation_rate[0] * dilation_rate[1]` non-dilated ones:
    `tf.space_to_batch_nd` splits the padded image into interleaved
    sub-images, each sub-image is processed with the kernels belonging to its
    own pixels, and `tf.batch_to_space_nd` interleaves the results back.

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`. `in_height` and
            `in_width` are integer-divided by the corresponding dilation
            rate, so they are expected to be multiples of it.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1],
            num_transformed_images]`.
        dilation_rate: An int, or a tuple/list of two ints (rows, columns).

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, in_height, in_width, in_channels]`.
    """
    if isinstance(dilation_rate, (tuple, list)):
        dilation_rate = list(dilation_rate)
    else:
        dilation_rate = [dilation_rate] * 2
    num_phases = dilation_rate[0] * dilation_rate[1]

    batch_size, height, width, color_channels = image.get_shape().as_list()
    # The kernel's two spatial extents must be unpacked separately: they are
    # needed both as a product and as a list. Binding them to a single name
    # made `kernel_size` an int and broke every use below.
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]
    patch_len = kernel_height * kernel_width
    small_height = height // dilation_rate[0]
    small_width = width // dilation_rate[1]

    # Flatten the spatial dimensions of the kernel.
    kernels_reshaped = tf.reshape(
        kernels,
        [batch_size, height, width, patch_len, num_transformed_images])

    image_padded = pad2d(image, kernel_size, rate=dilation_rate,
                         padding='SAME', mode='SYMMETRIC')
    padded_height, padded_width = image_padded.get_shape().as_list()[1:3]

    # for dilation = [2, 2], this is equivalent to this:
    # small_images = [image[:, 0::2, 0::2, :], image[:, 0::2, 1::2, :],
    #                 image[:, 1::2, 0::2, :], image[:, 1::2, 1::2, :]]
    small_images = tf.space_to_batch_nd(
        image_padded, dilation_rate, paddings=[[0, 0]] * 2)
    small_images = tf.reshape(
        small_images,
        [num_phases, batch_size,
         padded_height // dilation_rate[0],
         padded_width // dilation_rate[1],
         color_channels])
    small_images = tf.unstack(small_images, axis=0)

    # The kernels are per-pixel, so they have to be split exactly like the
    # image; otherwise their spatial size does not match that of the
    # sub-image patches and a sub-image would not be paired with the kernels
    # of its own pixels.
    # NOTE(review): this pairing assumes `pad2d` pads the top/left by a
    # multiple of the dilation rate (true for odd kernel sizes with centered
    # 'SAME' padding) -- confirm against `pad2d`.
    small_kernels = tf.space_to_batch_nd(
        kernels_reshaped, dilation_rate, paddings=[[0, 0]] * 2)
    small_kernels = tf.reshape(
        small_kernels,
        [num_phases, batch_size, small_height, small_width,
         patch_len, num_transformed_images])
    small_kernels = tf.unstack(small_kernels, axis=0)

    small_outputs = []
    for small_image, small_kernel in zip(small_images, small_kernels):
        # Combine channel and batch dimensions into the first dimension.
        image_transposed = tf.transpose(small_image, [3, 0, 1, 2])
        image_reshaped = flatten(image_transposed, 0, 1)[..., None]
        patches_reshaped = tf.extract_image_patches(
            image_reshaped,
            ksizes=[1] + kernel_size + [1],
            strides=[1] * 4,
            rates=[1] * 4,
            padding='VALID')
        # Separate channel and batch dimensions.
        patches = tf.reshape(
            patches_reshaped,
            [color_channels, batch_size, small_height, small_width,
             patch_len])
        # Reduce along the (flattened) spatial dimensions of the kernel.
        outputs = tf.reduce_sum(
            patches[..., None] * small_kernel[None, ...], axis=-2)
        # Swap channel and transformation dimensions.
        outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
        small_outputs.append(tf.unstack(outputs, axis=0))

    # Regroup from "per sub-image, list of transformations" to
    # "per transformation, tuple of sub-images".
    small_outputs = list(zip(*small_outputs))
    # tf.reshape packs each tuple of sub-image tensors into one tensor; fold
    # the sub-image axis into the batch axis as batch_to_space_nd expects.
    small_outputs = [
        tf.reshape(
            small_output,
            [num_phases * batch_size, small_height, small_width,
             color_channels])
        for small_output in small_outputs
    ]
    return [
        tf.batch_to_space_nd(small_output, dilation_rate,
                             crops=[[0, 0]] * 2)
        for small_output in small_outputs
    ]