def construct_model(self, random_init=False, kernel=5, trainable_upsampling=False, cfa_pattern='gbrg'):
    """
    Build the graph of a four-stage pipeline made entirely of convolutions:
    up-sampling -> demosaicing -> color space conversion -> gamma correction.

    The unclipped result is stored in ``self.yy`` and its version clipped to
    [0, 1] in ``self.y``. The two constructor arguments ``trainable_upsampling``
    and ``cfa_pattern`` are also recorded as attributes of the same names.

    :param random_init: if True, the demosaicing and gamma kernels are drawn from
        a normal distribution (mean 0, std 0.1), the gamma biases are zero and
        the color conversion starts as identity; otherwise the kernels come from
        ``bilin_kernel`` / ``gamma_kernels`` and a fixed example sRGB matrix
    :param kernel: spatial size of the demosaicing filter; the reflect padding of
        ``(kernel - 1) // 2`` exactly compensates the 'VALID' convolution only for
        odd sizes (an even size yields an output one pixel smaller per dimension)
    :param trainable_upsampling: whether the 1x1 up-sampling convolution is trainable
    :param cfa_pattern: color filter array pattern forwarded to ``upsampling_kernel``
    """
    self.trainable_upsampling = trainable_upsampling
    self.cfa_pattern = cfa_pattern

    with self.graph.as_default():
        with tf.variable_scope('{}'.format(self.scoped_name)):

            # The up-sampling weights are never randomized, even with random_init
            # (a random (4, 12) initialization was disabled in the original code).
            upsampling_weights = upsampling_kernel(cfa_pattern)

            if not random_init:
                # Bilinear demosaicing filter
                demosaic_weights = bilin_kernel(kernel)

                # Gamma correction weights (obtained from a pre-trained toy model)
                encode_weights, encode_bias, decode_weights, decode_bias = gamma_kernels()

                # Example sRGB conversion table
                color_table = np.array([
                    [1.82691061, -0.65497452, -0.17193617],
                    [-0.00683982, 1.33216381, -0.32532394],
                    [0.06269717, -0.40055895, 1.33786178],
                ])
                color_weights = color_table.T
            else:
                # Random draws are made in a fixed order (demosaicing, gamma encode,
                # gamma decode) so that a seeded generator yields the same weights.
                demosaic_weights = np.random.normal(0, 0.1, (kernel, kernel, 3, 3))
                encode_weights = np.random.normal(0, 0.1, (3, 12))
                decode_weights = np.random.normal(0, 0.1, (12, 3))
                encode_bias = np.zeros((12,))
                decode_bias = np.zeros((3,))
                color_weights = np.eye(3)

            # Stage 1: expand every input pixel to 12 channels with a 1x1 convolution
            with tf.variable_scope('upsampling'):
                h12 = tf.layers.conv2d(
                    self.x,
                    filters=12,
                    kernel_size=1,
                    kernel_initializer=tf.constant_initializer(upsampling_weights),
                    use_bias=False,
                    activation=None,
                    name='conv_h12',
                    trainable=trainable_upsampling,
                )

            # Stage 2: rearrange the 12 channels into a 3-channel image of twice the
            # spatial size, reflect-pad it and filter with the demosaicing kernel
            with tf.variable_scope('demosaicing'):
                margin = (kernel - 1) // 2
                border = [[0, 0], [margin, margin], [margin, margin], [0, 0]]
                mosaic = tf.depth_to_space(h12, 2)
                mosaic = tf.pad(mosaic, tf.constant(border), 'REFLECT')
                rgb = tf.layers.conv2d(
                    mosaic,
                    filters=3,
                    kernel_size=kernel,
                    kernel_initializer=tf.constant_initializer(demosaic_weights),
                    use_bias=False,
                    activation=None,
                    name='conv_demo',
                    padding='VALID',
                )

            # Stage 3: per-pixel 3x3 color transform (1x1 convolution without bias)
            with tf.variable_scope('rgb2sRGB'):
                srgb = tf.layers.conv2d(
                    rgb,
                    filters=3,
                    kernel_size=1,
                    kernel_initializer=tf.constant_initializer(color_weights),
                    use_bias=False,
                    activation=None,
                    name='conv_sRGB',
                )

            # Stage 4: gamma correction as a tiny per-pixel network (3 -> 12 -> 3, tanh)
            with tf.variable_scope('gamma'):
                encoded = tf.layers.conv2d(
                    srgb,
                    filters=12,
                    kernel_size=1,
                    kernel_initializer=tf.constant_initializer(encode_weights),
                    bias_initializer=tf.constant_initializer(encode_bias),
                    use_bias=True,
                    activation=tf.nn.tanh,
                    name='conv_encode',
                )
                self.yy = tf.layers.conv2d(
                    encoded,
                    filters=3,
                    kernel_size=1,
                    kernel_initializer=tf.constant_initializer(decode_weights),
                    bias_initializer=tf.constant_initializer(decode_bias),
                    use_bias=True,
                    activation=None,
                    name='conv_decode',
                )

        # NOTE(review): the output op is created outside the variable scope because its
        # name already carries the scope prefix explicitly - confirm against the
        # original indentation, which was lost in this copy of the file.
        output_name = '{}/y'.format(self.scoped_name)
        self.y = tf.clip_by_value(self.yy, 0, 1, name=output_name)
def construct_model(self, n_layers=15, kernel=3, n_features=64):
    """
    Build the graph of a deep convolutional pipeline.

    A stack of ``n_layers`` ReLU convolutions runs on the sub-sampled input
    ``self.x``; its 12-channel output is rearranged to full resolution
    (3 channels), concatenated with the up-sampled input (3 channels), passed
    through one more ReLU convolution and finally projected to 3 output
    channels with a 1x1 convolution. The unclipped result is stored in
    ``self.yy`` and its version clipped to [0, 1] in ``self.y``.

    Tensor shapes of the intermediate results are printed to stdout.

    :param n_layers: number of convolutions applied to the sub-sampled input;
        the last one always produces 12 channels
    :param kernel: spatial size of the convolution filters; must be a positive
        odd number (see below)
    :param n_features: number of feature maps of the hidden convolutions and of
        the post-upscale convolution
    :raises ValueError: if ``kernel`` is not a positive odd number
    """
    # Every 'VALID' convolution removes (kernel - 1) pixels per dimension and the
    # reflect padding gives back 2 * ((kernel - 1) // 2). These only cancel for odd
    # kernels; an even kernel would shrink the features by one pixel per layer and
    # the concatenation with the full-resolution input below would then fail with
    # an obscure shape mismatch. Fail early with a clear message instead.
    if kernel < 1 or kernel % 2 == 0:
        raise ValueError('kernel must be a positive odd integer, got {}'.format(kernel))

    with self.graph.as_default():
        with tf.name_scope('{}'.format(self.scoped_name)):
            k_initializer = tf.variance_scaling_initializer

            # Initialize the upsampling kernel
            upk = upsampling_kernel()

            # Padding that restores the spatial size lost by a 'VALID' convolution
            pad = (kernel - 1) // 2
            paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]

            # Convolutions on the sub-sampled input tensor; the output of each layer
            # is reflect-padded so that the next layer sees the original spatial size
            deep_x = self.x
            for r in range(n_layers):
                # The last layer emits 12 channels so that depth_to_space (block 2)
                # can turn them into 3 full-resolution channels
                n_out = 12 if r == n_layers - 1 else n_features
                deep_y = tf.layers.conv2d(
                    deep_x, n_out, kernel,
                    activation=tf.nn.relu,
                    name='{}/conv{}'.format(self.scoped_name, r),
                    padding='VALID',
                    kernel_initializer=k_initializer)
                deep_x = tf.pad(deep_y, tf.constant(paddings), 'REFLECT')

            # Up-sample the input with a fixed (non-trainable) 1x1 convolution
            h12 = tf.layers.conv2d(
                self.x, 12, 1,
                kernel_initializer=tf.constant_initializer(upk),
                use_bias=False,
                activation=None,
                name='{}/conv_h12'.format(self.scoped_name),
                trainable=False)
            bayer = tf.depth_to_space(h12, 2, name="{}/upscaled_bayer".format(self.scoped_name))

            # Upscale the conv. features and concatenate with the input RGB channels
            features = tf.depth_to_space(deep_x, 2, name='{}/upscaled_features'.format(self.scoped_name))
            bayer_features = tf.concat((features, bayer), axis=3)

            print('Final deep X: {}'.format(deep_x.shape))
            print('Bayer shape: {}'.format(bayer.shape))
            print('Features shape: {}'.format(features.shape))
            print('Concat shape: {}'.format(bayer_features.shape))

            # Project the concatenated 6-D features (R G B bayer from input + 3 channels
            # from convolutions) to n_features channels
            pu = tf.layers.conv2d(
                bayer_features, n_features, kernel,
                kernel_initializer=k_initializer,
                use_bias=True,
                activation=tf.nn.relu,
                name='{}/conv_postupscale'.format(self.scoped_name),
                padding='VALID',
                bias_initializer=tf.zeros_initializer)
            print('Post upscale: {}'.format(pu.shape))

            # Restore the size lost by the 'VALID' convolution above, then use a final
            # 1x1 conv to project each n_features-D vector into the RGB colorspace
            pu = tf.pad(pu, tf.constant(paddings), 'REFLECT')
            rgb = tf.layers.conv2d(
                pu, 3, 1,
                kernel_initializer=tf.ones_initializer,
                use_bias=False,
                activation=None,
                name='{}/conv_final'.format(self.scoped_name),
                padding='VALID')
            print('RGB affine: {}'.format(rgb.shape))

            self.yy = rgb
            print('Y: {}'.format(self.yy.shape))

        # NOTE(review): the output op is created outside the name scope because its name
        # already carries the scope prefix explicitly - confirm against the original
        # indentation, which was lost in this copy of the file.
        self.y = tf.clip_by_value(self.yy, 0, 1, name='{}/y'.format(self.scoped_name))