def deterministic():
    # Zero out pruned weights: keep `zeros` where the prune mask is True
    # and the trained kernel elsewhere. (`tf.select` is the removed TF1 op;
    # `tf.where` is its replacement.)
    kern = tf.where(kernel_prune_mask, zeros, self.kernel)
    # Deterministic convolution (no reparameterization noise).
    outputs = nn.convolution(
        input=inputs,
        filter=kern,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=convert_data_format(self.data_format, self.rank + 2))
    # Bias
    if self.bias is not None:
        if self.data_format == "channels_first":
            if self.rank == 1:
                bias = tf.reshape(self.bias, (1, self.filters, 1))
                outputs += bias
            if self.rank == 2:
                outputs = tf.nn.bias_add(outputs, self.bias,
                                         data_format="NCHW")
            if self.rank == 3:
                # `bias_add` has no 5-D NCDHW path, so fold the two middle
                # spatial dims into one, add the bias in 4-D, then restore
                # the original shape.
                outputs_shape = outputs.shape.as_list()
                outputs_4d = tf.reshape(
                    outputs,
                    [outputs_shape[0], outputs_shape[1],
                     outputs_shape[2] * outputs_shape[3], outputs_shape[4]])
                outputs_4d = tf.nn.bias_add(outputs_4d, self.bias,
                                            data_format="NCHW")
                outputs = tf.reshape(outputs_4d, outputs_shape)
        else:
            outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
    # Activation
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
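# For reference, a minimal standalone sketch (not part of the layer) of the
# masking step above: `tf.where(mask, a, b)` takes `a` where the mask is True
# and `b` elsewhere, which is how the prune mask zeroes out weights.
# All names and values here are illustrative.
import tensorflow as tf

kernel = tf.constant([[0.5, -1.2], [0.3, 0.9]])
prune_mask = tf.constant([[True, False], [False, True]])  # True = pruned
pruned = tf.where(prune_mask, tf.zeros_like(kernel), kernel)
print(pruned.numpy())  # [[ 0.  -1.2], [ 0.3  0. ]]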
def build_psnr_metrics(with_y_true=True, addition_imgs=None, data_format=None):
    metrics = []
    if with_y_true:
        def psnr(y_true, y_pred):
            # `tf.image.psnr` expects NHWC, so transpose if needed.
            if utils.convert_data_format(data_format) == "NCHW":
                y_true = tf.transpose(y_true, [0, 2, 3, 1])
                y_pred = tf.transpose(y_pred, [0, 2, 3, 1])
            return tf.image.psnr(y_true, y_pred, max_val=1.0)

        metrics.append(psnr)
    if addition_imgs is None:
        return metrics
    for name, img in addition_imgs.items():
        if tf.rank(img) == 3:
            img = tf.expand_dims(img, axis=0)
        if utils.convert_data_format(data_format) == "NCHW":
            img = tf.transpose(img, [0, 2, 3, 1])

        # Bind `img` as a default argument: a plain closure would capture
        # the loop variable late, and every metric would compare against
        # the last image.
        def psnr(_, y_pred, img=img):
            if utils.convert_data_format(data_format) == "NCHW":
                y_pred = tf.transpose(y_pred, [0, 2, 3, 1])
            return tf.image.psnr(img, y_pred, max_val=1.0)

        psnr.__name__ = f"psnr_{name}"
        metrics.append(psnr)
    return metrics
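# Usage sketch (assumed, not taken from this codebase): wiring the metrics
# into `model.compile`. The tiny model and the target image are placeholders.
import tensorflow as tf

clean_img = tf.zeros([1, 32, 32, 3])  # stand-in float image in [0, 1]
model = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding="same")])
model.compile(
    optimizer="adam",
    loss="mse",
    metrics=build_psnr_metrics(
        with_y_true=True,
        addition_imgs={"clean": clean_img},
        data_format="NHWC",
    ),
)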
def conv_sp_var_drop():
    input2 = tf.multiply(inputs, inputs)
    if train_prune:
        # Clip pruned weights, as suggested by the author's implementation.
        kern = tf.where(kernel_prune_mask, self.kernel, kernel_clip_mask)
    else:
        # Fall back to the raw kernel so `kern` is always defined.
        kern = self.kernel
    # Convolution with the (local) reparameterization trick: sample the
    # pre-activations from a Gaussian with matching mean and variance.
    theta = nn.convolution(
        input=inputs,
        filter=kern,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=convert_data_format(self.data_format, self.rank + 2))
    sigma = tf.sqrt(nn.convolution(
        input=input2,
        filter=tf.exp(log_alpha) * kern * kern,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=convert_data_format(self.data_format, self.rank + 2)))
    noise = tf.random.normal(shape=theta.shape.as_list())
    outputs = theta + noise * sigma
    # Bias
    if self.bias is not None:
        if self.data_format == "channels_first":
            if self.rank == 1:
                bias = tf.reshape(self.bias, (1, self.filters, 1))
                outputs += bias
            if self.rank == 2:
                outputs = tf.nn.bias_add(outputs, self.bias,
                                         data_format="NCHW")
            if self.rank == 3:
                # Fold the two middle spatial dims so `bias_add` can run
                # on a 4-D NCHW tensor, then restore the original shape.
                outputs_shape = outputs.shape.as_list()
                outputs_4d = tf.reshape(
                    outputs,
                    [outputs_shape[0], outputs_shape[1],
                     outputs_shape[2] * outputs_shape[3], outputs_shape[4]])
                outputs_4d = tf.nn.bias_add(outputs_4d, self.bias,
                                            data_format="NCHW")
                outputs = tf.reshape(outputs_4d, outputs_shape)
        else:
            outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
    # Activation
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
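# The mean/variance computation above is the local reparameterization trick:
# instead of sampling noisy weights, sample the pre-activations from a
# Gaussian with mean conv(x, W) and variance conv(x^2, exp(log_alpha) * W^2).
# A minimal dense-layer sketch of the same idea (all names invented here):
import tensorflow as tf

x = tf.random.normal([8, 16])        # batch of inputs
w = tf.random.normal([16, 4])        # weight means
log_alpha = tf.fill([16, 4], -3.0)   # per-weight log noise ratio

theta = tf.matmul(x, w)                            # pre-activation mean
var = tf.matmul(x * x, tf.exp(log_alpha) * w * w)  # pre-activation variance
out = theta + tf.sqrt(var + 1e-8) * tf.random.normal(tf.shape(theta))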
def conv_var_drop():
    input2 = tf.multiply(inputs, inputs)
    alpha = tf.clip_by_value(self.alpha, 0, 1)
    # Convolution with the (local) reparameterization trick.
    theta = nn.convolution(
        input=inputs,
        filter=self.kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=convert_data_format(self.data_format, self.rank + 2))
    # `eps` keeps the sqrt numerically stable when the variance vanishes.
    sigma = tf.sqrt(eps + nn.convolution(
        input=input2,
        filter=alpha * self.kernel * self.kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=convert_data_format(self.data_format, self.rank + 2)))
    noise = tf.random.normal(shape=theta.shape.as_list())
    outputs = theta + noise * sigma
    # Bias
    if self.bias is not None:
        if self.data_format == "channels_first":
            if self.rank == 1:
                bias = tf.reshape(self.bias, (1, self.filters, 1))
                outputs += bias
            if self.rank == 2:
                outputs = tf.nn.bias_add(outputs, self.bias,
                                         data_format="NCHW")
            if self.rank == 3:
                # Fold the two middle spatial dims so `bias_add` can run
                # on a 4-D NCHW tensor, then restore the original shape.
                outputs_shape = outputs.shape.as_list()
                outputs_4d = tf.reshape(
                    outputs,
                    [outputs_shape[0], outputs_shape[1],
                     outputs_shape[2] * outputs_shape[3], outputs_shape[4]])
                outputs_4d = tf.nn.bias_add(outputs_4d, self.bias,
                                            data_format="NCHW")
                outputs = tf.reshape(outputs_4d, outputs_shape)
        else:
            outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
    # Activation
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
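# A note on the clip: in Gaussian-dropout terms alpha = p / (1 - p), so
# capping alpha at 1 corresponds to capping the effective dropout rate at
# 0.5 (our reading of the clip above, not stated in this code). The helper
# below is purely illustrative:
def dropout_rate_from_alpha(alpha):
    """Effective Gaussian-dropout rate p for alpha = p / (1 - p)."""
    return alpha / (1.0 + alpha)

print(dropout_rate_from_alpha(1.0))   # 0.5, the most noise the clip allows
print(dropout_rate_from_alpha(0.25))  # 0.2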
def __init__(
    self,
    filters,
    kernel_size,
    strides=1,
    padding_mode="zero",
    data_format="NHWC",
    use_bias=True,
    **kwargs,
):
    super().__init__(trainable=True, **kwargs)
    self.filters = int(filters)
    self.kernel_size = int(kernel_size)
    self.strides = strides
    self.padding_mode = str(padding_mode.lower())
    # Zero padding corresponds to the "constant" mode (as in `tf.pad`).
    if self.padding_mode == "zero":
        self.padding_mode = "constant"
    self.data_format = str(utils.convert_data_format(data_format))
    self.use_bias = bool(use_bias)
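# Hypothetical instantiation, assuming this is the constructor of the
# `ConvWithPad2D` layer that `build_skip_net` below refers to (values are
# illustrative):
conv = ConvWithPad2D(
    filters=64,
    kernel_size=3,
    strides=1,
    padding_mode="reflect",  # presumably maps to tf.pad's REFLECT mode
    data_format="NHWC",
    use_bias=True,
)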
def build_skip_net(
    input_channels=2,
    output_channels=3,
    levels=5,
    down_channels=128,
    up_channels=128,
    skip_channels=4,
    down_sizes=3,
    up_sizes=3,
    skip_sizes=1,
    downsample_modes="stride",
    upsample_modes="nearest",
    padding_mode="zero",
    use_sigmoid=True,
    use_bias=True,
    use_1x1up=True,
    data_format=None,
    activations=tf.keras.layers.LeakyReLU,
):
    """Build an autoencoder network with skip connections.

    Arguments:
        input_channels: Integer.
        output_channels: Integer.
        levels: Integer. Number of encoder and decoder pairs.
        down_channels: An integer or tuple/list of integers. How many
            channels in each encoder.
        up_channels: An integer or tuple/list of integers. How many
            channels in each decoder.
        skip_channels: An integer or tuple/list of integers. How many
            channels in each skip connection.
        down_sizes: An integer or tuple/list of integers. `kernel_size` of
            each encoder.
        up_sizes: An integer or tuple/list of integers. `kernel_size` of
            each decoder.
        skip_sizes: An integer or tuple/list of integers. `kernel_size` of
            each skip connection.
        downsample_modes: A string or tuple/list of strings. One of
            `"stride"`.
        upsample_modes: A string or tuple/list of strings. One of
            `"nearest"` or `"bilinear"`.
        padding_mode: One of `"constant"`, `"reflect"`, `"symmetric"`, or
            `"zero"` (case-insensitive).
        use_sigmoid: Boolean.
        use_bias: Boolean, whether the layers use a bias vector.
        use_1x1up: Boolean.
        data_format: A string, one of `"NHWC"` or `"NCHW"`.
        activations: Activation function to use.

    Returns:
        A `tf.keras.Model`.
    """
    data_format = utils.convert_data_format(data_format)
    keras_data_format = utils.convert_data_format(data_format, False)
    channels_axis = -1
    input_shape = (None, None, input_channels)
    if data_format == "NCHW":
        channels_axis = 1
        input_shape = (input_channels, None, None)

    # Broadcast scalar arguments into per-level values. (`shape` must be a
    # sequence, so wrap `levels` in a list.)
    down_channels = tf.constant(down_channels, shape=[levels])
    up_channels = tf.constant(up_channels, shape=[levels])
    skip_channels = tf.constant(skip_channels, shape=[levels])
    down_sizes = tf.constant(down_sizes, shape=[levels])
    up_sizes = tf.constant(up_sizes, shape=[levels])
    skip_sizes = tf.constant(skip_sizes, shape=[levels])
    downsample_modes = (tf.constant(downsample_modes, shape=[levels],
                                    dtype=tf.string).numpy().astype(str))
    upsample_modes = (tf.constant(upsample_modes, shape=[levels],
                                  dtype=tf.string).numpy().astype(str))

    inputs = tf.keras.layers.Input(shape=input_shape)

    # First, we add layers along the deeper branch.
    deeper_startnodes = [None] * (levels + 1)
    deeper_startnodes[0] = inputs
    for i in range(levels):
        with tf.name_scope(f"deeper_{i}"):
            output = layers.ConvWithPad2D(
                down_channels[i],
                down_sizes[i],
                strides=2,
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(deeper_startnodes[i])
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            output = activations()(output)
            output = layers.ConvWithPad2D(
                down_channels[i],
                down_sizes[i],
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(output)
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            deeper_startnodes[i + 1] = activations()(output)

    # Second, we add skip connections (if any) to the deeper main nodes.
    skip_nodes = [None] * levels
    for i in range(levels):
        with tf.name_scope(f"skip_{i}"):
            if skip_channels[i] != 0:
                output = layers.ConvWithPad2D(
                    skip_channels[i],
                    skip_sizes[i],
                    padding_mode=padding_mode,
                    use_bias=use_bias,
                    data_format=data_format,
                )(deeper_startnodes[i])
                output = tf.keras.layers.BatchNormalization(
                    axis=channels_axis)(output)
                skip_nodes[i] = activations()(output)

    # Finally, we concatenate each skip connection with its deeper branch,
    # or pass the deeper branch through when there is no skip connection.
    # The final layers of each deeper level have to be connected to the
    # ending node of the sublevels, so this loop builds `deeper_endnodes`
    # starting from the deepest level.
    deeper_endnodes = [None] * (levels + 1)
    deeper_endnodes[-1] = deeper_startnodes[-1]
    # Reversed loop because we build the deepest level first.
    for i in range(levels - 1, -1, -1):
        with tf.name_scope(f"deeper_{i}"):
            # Upsample before concatenation.
            deeper_endnodes[i + 1] = tf.keras.layers.UpSampling2D(
                interpolation=upsample_modes[i],
                data_format=keras_data_format)(deeper_endnodes[i + 1])
            output = (tf.keras.layers.Concatenate(
                axis=channels_axis)([skip_nodes[i], deeper_endnodes[i + 1]])
                      if skip_channels[i] != 0 else deeper_endnodes[i + 1])
            # Final layers for this deeper level.
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            output = layers.ConvWithPad2D(
                up_channels[i],
                up_sizes[i],
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(output)
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            output = activations()(output)
            if use_1x1up:
                output = layers.ConvWithPad2D(
                    up_channels[i],
                    1,
                    padding_mode=padding_mode,
                    use_bias=use_bias,
                    data_format=data_format,
                )(output)
                output = tf.keras.layers.BatchNormalization(
                    axis=channels_axis)(output)
                output = activations()(output)
            deeper_endnodes[i] = output

    # Final touches.
    outputs = layers.ConvWithPad2D(
        output_channels,
        1,
        padding_mode=padding_mode,
        use_bias=use_bias,
        data_format=data_format,
    )(deeper_endnodes[0])
    if use_sigmoid:
        outputs = tf.keras.layers.Activation(tf.nn.sigmoid)(outputs)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
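# Example (illustrative argument values, not defaults from any experiment):
model = build_skip_net(
    input_channels=32,
    output_channels=3,
    levels=5,
    skip_channels=4,
    upsample_modes="bilinear",
    data_format="NHWC",
)
model.summary()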