def apply_cdna_kernels(image, kernels, dilation_rate=(1, 1)):
    """Convolve each sample of `image` with its own set of kernels.

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 4-D tensor of shape
            `[batch, kernel_size[0], kernel_size[1], num_transformed_images]`.
        dilation_rate: Forwarded as `rate` to both `pad2d` and
            `tf.nn.depthwise_conv2d`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
            `[batch, in_height, in_width, in_channels]`.
    """
    # Static shapes are needed because they feed the reshapes below.
    _, rows, cols, num_channels = image.get_shape().as_list()
    num_samples, k_rows, k_cols, num_outputs = kernels.get_shape().as_list()
    k_shape = [k_rows, k_cols]

    # Pad up front so the 'VALID' convolution below yields `rows x cols`
    # outputs (relies on pad2d's 'SAME' padding -- defined elsewhere).
    padded = pad2d(image, k_shape, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')

    # depthwise_conv2d applies an independent filter bank to every input
    # channel. Putting the batch on the channel axis therefore gives every
    # sample its own kernels, while the color channels (which all share the
    # same transformation) ride along on the batch axis.
    filters = tf.transpose(kernels, [1, 2, 0, 3])
    filters = tf.reshape(filters, [k_shape[0], k_shape[1], num_samples, num_outputs])
    swapped = tf.transpose(padded, [3, 1, 2, 0])

    convolved = tf.nn.depthwise_conv2d(swapped, filters, [1, 1, 1, 1], padding='VALID', rate=dilation_rate)

    # The conv output's last axis is (sample, transformation) flattened;
    # split it, then move the axes to [transformation, batch, rows, cols, channel].
    convolved = tf.reshape(convolved, [num_channels, rows, cols, num_samples, num_outputs])
    convolved = tf.transpose(convolved, [4, 3, 1, 2, 0])
    return tf.unstack(convolved, axis=0)
# Example #2
# 0
def apply_dna_kernels_non_dilated(image, kernels):
    """Apply a separate kernel at every pixel of `image` (no dilation).

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1],
            num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
            `[batch, in_height, in_width, in_channels]`.
    """
    batch_size, height, width, color_channels = image.get_shape().as_list()
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    # Kept as a list: it is indexed below and concatenated into `ksizes`.
    kernel_size = [kernel_height, kernel_width]
    num_taps = kernel_size[0] * kernel_size[1]

    # Flatten the two spatial dimensions of each kernel into one axis so it
    # lines up with the flattened patches produced further down.
    kernels_reshaped = tf.reshape(kernels, [
        batch_size, height, width, num_taps, num_transformed_images
    ])
    # The patch reshape below relies on pad2d padding the image such that
    # 'VALID' patch extraction yields exactly `height x width` positions.
    image_padded = pad2d(image, kernel_size, padding='SAME', mode='SYMMETRIC')
    # Combine channel and batch dimensions into the first dimension, and add
    # a singleton channel axis, so that each color channel is patched
    # independently (`flatten` is defined elsewhere -- presumably it merges
    # axes 0 and 1).
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(image_reshaped,
                                                ksizes=[1] + kernel_size + [1],
                                                strides=[1] * 4,
                                                rates=[1] * 4,
                                                padding='VALID')
    # Separate channel and batch dimensions.
    patches = tf.reshape(patches_reshaped, [
        color_channels, batch_size, height, width, num_taps
    ])
    # Weight every patch element by its kernel tap and reduce along the
    # (flattened) spatial dimensions of the kernel. Broadcasting pairs the
    # channel axis of the patches with the transformation axis of the kernels.
    outputs = tf.reduce_sum(patches[..., None] * kernels_reshaped[None, ...],
                            axis=-2)
    # Swap channel and transformation dimensions:
    # [channel, batch, h, w, transformation] -> [transformation, batch, h, w, channel].
    outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
    outputs = tf.unstack(outputs, axis=0)
    return outputs
def apply_dna_kernels_dilated(image, kernels, dilation_rate=(1, 1)):
    """Apply a separate, dilated kernel at every pixel of `image`.

    The dilated operation is computed by splitting the padded image (and the
    per-pixel kernels) into `dilation_rate[0] * dilation_rate[1]` interleaved
    sub-grids, running the non-dilated computation on each sub-grid, and
    interleaving the results back together.

    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1],
            num_transformed_images]`.
        dilation_rate: An int, or a tuple/list of two ints. An int is used
            for both spatial dimensions.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
            `[batch, in_height, in_width, in_channels]`.
    """
    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
    batch_size, height, width, color_channels = image.get_shape().as_list()
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    # Kept as a list: it is indexed below and concatenated into `ksizes`.
    kernel_size = [kernel_height, kernel_width]
    num_taps = kernel_size[0] * kernel_size[1]
    num_phases = dilation_rate[0] * dilation_rate[1]
    small_height = height // dilation_rate[0]
    small_width = width // dilation_rate[1]

    # Flatten the spatial dimensions of each kernel.
    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width, num_taps, num_transformed_images])
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    padded_height, padded_width = image_padded.get_shape().as_list()[1:3]

    # for dilation = [2, 2], this is equivalent to this:
    # small_images = [image[:, 0::2, 0::2, :], image[:, 0::2, 1::2, :], image[:, 1::2, 0::2, :], image[:, 1::2, 1::2, :]]
    # space_to_batch_nd stacks the sub-grids along the batch axis with the
    # sub-grid index outermost, hence the reshape to [num_phases, batch, ...].
    small_images = tf.space_to_batch_nd(image_padded, dilation_rate, paddings=[[0, 0]] * 2)
    small_images = tf.reshape(small_images, [num_phases, batch_size,
                                             padded_height // dilation_rate[0],
                                             padded_width // dilation_rate[1],
                                             color_channels])
    small_images = tf.unstack(small_images, axis=0)

    # Split the per-pixel kernels the same way, so that every sub-image is
    # combined with the kernels of exactly the pixels it contains. Without
    # this the full-resolution kernels do not match the sub-image patches.
    small_kernels = tf.space_to_batch_nd(kernels_reshaped, dilation_rate, paddings=[[0, 0]] * 2)
    small_kernels = tf.reshape(small_kernels, [num_phases, batch_size, small_height, small_width,
                                               num_taps, num_transformed_images])
    small_kernels = tf.unstack(small_kernels, axis=0)

    small_outputs = []
    for small_image, small_kernel in zip(small_images, small_kernels):
        # Combine channel and batch dimensions into the first dimension
        # (`flatten` is defined elsewhere -- presumably it merges axes 0 and 1).
        image_transposed = tf.transpose(small_image, [3, 0, 1, 2])
        image_reshaped = flatten(image_transposed, 0, 1)[..., None]
        patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
                                                    strides=[1] * 4, rates=[1] * 4, padding='VALID')
        # Separate channel and batch dimensions.
        patches = tf.reshape(patches_reshaped, [color_channels, batch_size,
                                                small_height, small_width, num_taps])
        # Reduce along the spatial dimensions of the kernel.
        outputs = tf.reduce_sum(patches[..., None] * small_kernel[None, ...], axis=-2)
        # Swap channel and transformation dimensions.
        outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
        outputs = tf.unstack(outputs, axis=0)
        small_outputs.append(outputs)

    # Regroup from "per sub-grid, per transformation" to
    # "per transformation, per sub-grid".
    small_outputs = list(zip(*small_outputs))
    # tf.reshape implicitly stacks each tuple of sub-grid outputs; the result
    # has the sub-grid index outermost on the batch axis, which is the layout
    # batch_to_space_nd expects.
    small_outputs = [tf.reshape(small_output, [num_phases * batch_size, small_height, small_width, color_channels])
                     for small_output in small_outputs]
    outputs = [tf.batch_to_space_nd(small_output, dilation_rate, crops=[[0, 0]] * 2) for small_output in small_outputs]
    return outputs