def __init__(self, srgb_mat=None, kernel=5, c_filters=(3,), cfa_pattern='gbrg', residual=False, brightness=None, **kwargs):
    """Set up the fixed up-sampling/color kernels and the demosaicing sub-layer.

    Args:
        srgb_mat: 3x3 color-conversion matrix (any array-like). Defaults to the
            identity. It is stored transposed and reshaped to (1, 1, 3, 3), i.e.
            as a 1x1 convolution kernel with 3 input and 3 output channels.
        kernel: Kernel size forwarded to ``layers.DemosaicingLayer``.
        c_filters: Filter configuration forwarded to ``layers.DemosaicingLayer``.
        cfa_pattern: Color-filter-array pattern passed to ``upsampling_kernel``.
        residual: Forwarded to ``layers.DemosaicingLayer``.
        brightness: Stored unchanged in ``self._brightness`` (its use is not
            visible in this method).
        **kwargs: Accepted but ignored. NOTE(review): not forwarded to the base
            class initializer.

    Raises:
        ValueError: If ``srgb_mat`` is not a 3x3 matrix.
    """
    super().__init__()
    # 1x1 kernel mapping 4 input channels to 12 outputs. It is a constant
    # tensor, not a variable, so it is never trained.
    up = upsampling_kernel(cfa_pattern).reshape((1, 1, 4, 12)).astype(np.float32)
    self._upsampling_kernel = tf.convert_to_tensor(up)
    if srgb_mat is None:
        srgb_mat = np.eye(3, dtype=np.float32)
    # Cast to float32 so the color matrix matches the up-sampling kernel's
    # dtype. A float64 matrix would otherwise produce a float64 tensor.
    srgb_mat = np.asarray(srgb_mat, dtype=np.float32)
    if srgb_mat.shape != (3, 3):
        raise ValueError('srgb_mat must be a 3x3 matrix, got shape {}'.format(srgb_mat.shape))
    # Transposing turns "matrix @ pixel" into the (in, out) layout of a 1x1 conv.
    self._srgb_mat = tf.convert_to_tensor(srgb_mat.T.reshape((1, 1, 3, 3)))
    self._demosaicing = layers.DemosaicingLayer(c_filters, kernel, 'leaky_relu', residual)
    self._brightness = brightness
def construct_model(self, random_init=False, kernel=5, trainable_upsampling=False, cfa_pattern='gbrg'):
    """Build a small convolutional ISP: up-sampling, demosaicing, color conversion, gamma.

    The graph is built on ``self.x``. The clipped output is stored in ``self.y``
    and the Keras model in ``self._model``.

    Args:
        random_init: If True, the demosaicing filter and the gamma kernels are
            drawn from N(0, 0.1) (biases zero) and the color matrix is the
            identity. Otherwise a bilinear demosaicing filter, the kernels from
            ``gamma_kernels()`` and a fixed example sRGB matrix are used.
        kernel: Spatial size of the demosaicing filter.
        trainable_upsampling: Whether the 1x1 up-sampling convolution is trainable.
        cfa_pattern: One of 'gbrg', 'rggb', 'bggr'; passed to ``upsampling_kernel``.
    """
    self._h = paramspec.ParamSpec({
        'random_init': (False, bool, None),
        'kernel': (5, int, (3, 11)),
        'trainable_upsampling': (False, bool, None),
        'cfa_pattern': ('gbrg', str, {'gbrg', 'rggb', 'bggr'}),
    })
    # Copy the call arguments that match declared hyper-parameters into the
    # spec (presumably validated there against the ranges/sets above).
    params = locals()
    self._h.update(**{k: params[k] for k in self._h.keys() if k in params})

    # Initialize the upsampling kernel
    upk = upsampling_kernel(self._h.cfa_pattern)

    if self._h.random_init:
        dmf = np.random.normal(0, 0.1, (self._h.kernel, self._h.kernel, 3, 3))
        gamma_d1k = np.random.normal(0, 0.1, (3, 12))
        gamma_d1b = np.zeros((12,))
        gamma_d2k = np.random.normal(0, 0.1, (12, 3))
        gamma_d2b = np.zeros((3,))
        srgbk = np.eye(3)
    else:
        # Prepare demosaicing kernels (bilinear)
        dmf = bilin_kernel(self._h.kernel)
        # Prepare gamma correction kernels (obtained from a pre-trained toy model)
        gamma_d1k, gamma_d1b, gamma_d2k, gamma_d2b = gamma_kernels()
        # Example sRGB conversion table (transposed into the 1x1-conv (in, out) layout)
        srgbk = np.array([[ 1.82691061, -0.65497452, -0.17193617],
                          [-0.00683982,  1.33216381, -0.32532394],
                          [ 0.06269717, -0.40055895,  1.33786178]]).transpose()

    # Up-sample the input back to the full resolution (12 channels -> 2x2 blocks below)
    h12 = tf.keras.layers.Conv2D(12, 1, kernel_initializer=tf.constant_initializer(upk), use_bias=False,
                                 activation=None, trainable=self._h.trainable_upsampling)(self.x)

    # Demosaicing. The VALID convolution removes (kernel - 1) pixels per axis,
    # so exactly that many are reflect-padded. Odd kernels get (kernel - 1) // 2
    # on each side; even kernels need one extra pixel on the trailing side so
    # the output keeps the input size.
    pad_before = (self._h.kernel - 1) // 2
    pad_after = self._h.kernel - 1 - pad_before
    bayer = tf.nn.depth_to_space(h12, 2)
    bayer = tf.pad(bayer, tf.constant([[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]]), 'REFLECT')
    rgb = tf.keras.layers.Conv2D(3, self._h.kernel, kernel_initializer=tf.constant_initializer(dmf), use_bias=False,
                                 activation=None, padding='VALID')(bayer)

    # Color space conversion (1x1 conv = per-pixel 3x3 matrix)
    srgb = tf.keras.layers.Conv2D(3, 1, kernel_initializer=tf.constant_initializer(srgbk), use_bias=False,
                                  activation=None)(rgb)

    # Gamma correction: a tiny per-pixel MLP (3 -> 12 tanh -> 3)
    rgb_g0 = tf.keras.layers.Conv2D(12, 1, kernel_initializer=tf.constant_initializer(gamma_d1k),
                                    bias_initializer=tf.constant_initializer(gamma_d1b), use_bias=True,
                                    activation=tf.keras.activations.tanh)(srgb)
    y = tf.keras.layers.Conv2D(3, 1, kernel_initializer=tf.constant_initializer(gamma_d2k),
                               bias_initializer=tf.constant_initializer(gamma_d2b), use_bias=True,
                               activation=None)(rgb_g0)

    # Straight-through clipping: the forward value is clipped to [0, 1] while
    # the stop_gradient makes the backward pass behave like the identity.
    self.y = tf.stop_gradient(tf.clip_by_value(y, 0, 1) - y) + y
    self._model = tf.keras.Model(inputs=[self.x], outputs=[self.y])
def construct_model(self, n_layers=15, kernel=3, n_features=64):
    """Build a learned demosaicing network on top of a fixed Bayer up-sampling.

    A stack of ``n_layers`` ReLU convolutions runs on the sub-sampled input
    ``self.x``; the last one emits 12 channels. ``depth_to_space`` turns those
    into a 3-channel map at twice the resolution. That map is concatenated with
    the fixed up-sampled Bayer image (also 3 channels) and projected to
    ``n_features`` channels. A final 1x1 convolution maps the result to RGB.
    The clipped output is stored in ``self.y`` and the Keras model in
    ``self._model``.

    Args:
        n_layers: Number of convolutions applied to the sub-sampled input.
        kernel: Spatial kernel size of those convolutions and of the projection.
        n_features: Number of channels of the hidden convolutions.
    """
    self._h = paramspec.ParamSpec({
        'n_layers': (15, int, (1, 32)),
        'kernel': (3, int, (3, 11)),
        'n_features': (64, int, (4, 128)),
    })
    # Copy the matching call arguments into the hyper-parameter spec.
    params = locals()
    self._h.update(**{k: params[k] for k in self._h.keys() if k in params})

    # NOTE(review): the initializer class (not an instance) is passed to the
    # layers; presumably Keras instantiates it per layer - confirm for the TF
    # version in use.
    k_initializer = tf.keras.initializers.VarianceScaling

    # Initialize the upsampling kernel (default CFA pattern)
    upk = upsampling_kernel()

    # Each VALID convolution removes (kernel - 1) pixels per axis; reflect-pad
    # exactly that many back. Even kernels need one extra pixel on the trailing
    # side, otherwise the maps shrink and the concat below fails.
    pad_before = (self._h.kernel - 1) // 2
    pad_after = self._h.kernel - 1 - pad_before
    paddings = [[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]]

    # Convolutions on the sub-sampled input tensor.
    # NOTE(review): padding is applied AFTER each convolution, so the first
    # layer sees the unpadded input and the border values are reflected outputs
    # rather than outputs of a reflected input.
    deep_x = self.x
    for r in range(self._h.n_layers):
        n_out = 12 if r == self._h.n_layers - 1 else self._h.n_features
        deep_y = tf.keras.layers.Conv2D(n_out, self._h.kernel, activation=tf.keras.activations.relu,
                                        padding='VALID', kernel_initializer=k_initializer)(deep_x)
        deep_x = tf.pad(deep_y, tf.constant(paddings), 'REFLECT')

    # Up-sample the input with the fixed (non-trainable) kernel
    h12 = tf.keras.layers.Conv2D(12, 1, kernel_initializer=tf.constant_initializer(upk), use_bias=False,
                                 activation=None, trainable=False)(self.x)
    bayer = tf.nn.depth_to_space(h12, 2)

    # Upscale the conv. features and concatenate with the input RGB channels
    features = tf.nn.depth_to_space(deep_x, 2)
    bayer_features = tf.concat((features, bayer), axis=3)

    # Project the concatenated 6-D features (3 Bayer channels from the input +
    # 3 channels from the convolutions)
    pu = tf.keras.layers.Conv2D(self._h.n_features, self._h.kernel, kernel_initializer=k_initializer,
                                use_bias=True, activation=tf.keras.activations.relu, padding='VALID',
                                bias_initializer=tf.zeros_initializer)(bayer_features)
    pu = tf.pad(pu, tf.constant(paddings), 'REFLECT')

    # Final 1x1 conv projecting each feature vector into the RGB colorspace.
    # With the all-ones init every output channel starts as the feature sum.
    y = tf.keras.layers.Conv2D(3, 1, kernel_initializer=tf.ones_initializer, use_bias=False,
                               activation=None, padding='VALID')(pu)

    # Straight-through clipping: forward clipped to [0, 1], identity gradient.
    self.y = tf.stop_gradient(tf.clip_by_value(y, 0, 1) - y) + y
    self._model = tf.keras.Model(inputs=[self.x], outputs=[self.y])
def process(self, x, srgb_mat=None, cfa_pattern='gbrg', brightness='percentile'):
    """Run a fixed (non-learned) ISP on ``x``.

    The steps are: up-sampling, bilinear demosaicing (5x5), color conversion,
    brightness normalization, clipping and gamma 1/2.2.

    Args:
        x: Input tensor with 4 channels (the up-sampling kernel is reshaped to
            (1, 1, 4, 12)); assumed NHWC - TODO confirm.
        srgb_mat: 3x3 color matrix; defaults to the identity.
        cfa_pattern: Color-filter-array pattern passed to ``upsampling_kernel``.
        brightness: 'percentile' (stretch the 0.5 / 99.5 percentiles to 0 / 1),
            'shift' (scale so the global mean becomes 0.25) or None.
            The 'percentile' mode calls ``np.percentile`` on the tensor, so it
            needs eager execution.

    Returns:
        Tensor with values in [0, 1] after gamma correction.

    Raises:
        ValueError: If ``brightness`` is not one of the supported modes.
    """
    kernel = 5

    # Initialize upsampling and demosaicing kernels
    upk = upsampling_kernel(cfa_pattern).reshape((1, 1, 4, 12))
    dmf = bilin_kernel(kernel)

    # Setup sRGB color conversion (as a 1x1 conv kernel)
    if srgb_mat is None:
        srgb_mat = np.eye(3)
    srgb_mat = srgb_mat.T.reshape((1, 1, 3, 3))

    # Demosaicing & color space conversion
    pad = (kernel - 1) // 2
    h12 = tf.nn.conv2d(x, upk, [1, 1, 1, 1], 'SAME')
    bayer = tf.nn.depth_to_space(h12, 2)
    bayer = tf.pad(bayer, tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]), 'REFLECT')
    rgb = tf.nn.conv2d(bayer, dmf, [1, 1, 1, 1], 'VALID')

    # RGB -> sRGB
    rgb = tf.nn.conv2d(rgb, srgb_mat, [1, 1, 1, 1], 'SAME')

    # Brightness correction
    if brightness is not None:
        if brightness == 'percentile':
            percentile = 0.5
            rgb -= np.percentile(rgb, percentile)
            high = np.percentile(rgb, 100 - percentile)
            # A flat image gives high == 0; dividing would produce NaN/inf.
            if high > 0:
                rgb /= high
        elif brightness == 'shift':
            mean = tf.reduce_mean(rgb)
            # divide_no_nan returns 0 instead of inf for an all-black input.
            mult = tf.math.divide_no_nan(tf.cast(0.25, mean.dtype), mean)
            rgb *= mult
        else:
            raise ValueError('Brightness normalization not recognized!')

    # Straight-through clipping: forward clipped to [0, 1], identity gradient.
    y = rgb
    y = tf.stop_gradient(tf.clip_by_value(y, 0, 1) - y) + y

    # Gamma correction. d/dy y**(1/2.2) is infinite at y == 0, which would
    # turn gradients of clipped pixels into inf/NaN. The double tf.where keeps
    # the forward values identical (0 stays 0) while the gradient stays finite.
    positive = y > 0
    safe_y = tf.where(positive, y, tf.ones_like(y))
    y = tf.where(positive, tf.pow(safe_y, 1 / 2.2), tf.zeros_like(y))
    return y
def set_cfa_pattern(self, cfa_pattern):
    """Switch the wrapped model to another color-filter-array pattern.

    The pattern is lower-cased, used to rebuild the float32 up-sampling kernel
    on ``self._model`` and recorded in the hyper-parameters. Passing ``None``
    leaves everything unchanged.
    """
    if cfa_pattern is None:
        return
    pattern = cfa_pattern.lower()
    kernel_1x1 = upsampling_kernel(pattern).reshape((1, 1, 4, 12)).astype(np.float32)
    self._model._upsampling_kernel = tf.convert_to_tensor(kernel_1x1)
    self._h.update(cfa_pattern=pattern)