Example #1
 def deterministic():
     kern = tf.where(kernel_prune_mask, zeros, self.kernel)  # zero out pruned weights
     # Plain convolution (no noise is injected on the deterministic path)
     outputs = nn.convolution(input=inputs,
                  filter=kern,
                  dilation_rate=self.dilation_rate,
                  strides=self.strides,
                  padding=self.padding.upper(),
                  data_format=convert_data_format(self.data_format, self.rank+2))
     # bias
     if self.bias is not None:
         if self.data_format == "channels_first":
             if self.rank == 1:
                 bias = tf.reshape(self.bias, (1, self.filters, 1))
                 outputs += bias
             if self.rank == 2:
                 outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
             if self.rank == 3:
                 outputs_shape = outputs.shape.as_list()
                 outputs_4d = tf.reshape(outputs,
                                         [outputs_shape[0], outputs_shape[1],
                                          outputs_shape[2] * outputs_shape[3],
                                           outputs_shape[4]])
                 outputs_4d = tf.nn.bias_add(outputs_4d, self.bias, data_format="NCHW")
                 outputs = tf.reshape(outputs_4d, outputs_shape)
         else:
             outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
     # Activation
     if self.activation is not None:
         return self.activation(outputs)
     return outputs
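
The mask, `zeros`, and kernel above are free variables from the enclosing
layer. A minimal sketch of how such a pruning mask might be built, assuming a
sparse-variational-dropout setup; `log_alpha` and the threshold 3.0 are
illustrative assumptions, not taken from the example:

    import tensorflow as tf

    kernel = tf.random_normal([3, 3, 16, 32])     # illustrative conv weights
    log_alpha = tf.random_normal([3, 3, 16, 32])  # assumed dropout parameters
    kernel_prune_mask = log_alpha > 3.0           # prune where dropout is high
    kern = tf.where(kernel_prune_mask, tf.zeros_like(kernel), kernel)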
Example #2
def build_psnr_metrics(with_y_true=True, addition_imgs=None, data_format=None):
    metrics = []
    if with_y_true:

        def psnr(y_true, y_pred):
            if utils.convert_data_format(data_format) == "NCHW":
                y_true = tf.transpose(y_true, [0, 2, 3, 1])
                y_pred = tf.transpose(y_pred, [0, 2, 3, 1])
            return tf.image.psnr(y_true, y_pred, max_val=1.0)

        metrics.append(psnr)

    if addition_imgs is None:
        return metrics

    for name, img in addition_imgs.items():
        if img.shape.ndims == 3:  # static rank check; tf.rank would return a tensor
            img = tf.expand_dims(img, axis=0)
        if utils.convert_data_format(data_format) == "NCHW":
            img = tf.transpose(img, [0, 2, 3, 1])

        def psnr(_, y_pred, img=img):  # default arg binds the current img (avoids late binding)
            if utils.convert_data_format(data_format) == "NCHW":
                y_pred = tf.transpose(y_pred, [0, 2, 3, 1])
            return tf.image.psnr(img, y_pred, max_val=1.0)

        psnr.__name__ = f"psnr_{name}"

        metrics.append(psnr)

    return metrics
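
A usage sketch for the builder above; `model` and `clean_img` are assumed
placeholders (a Keras model and a reference image tensor in [0, 1]), not part
of the example:

    metrics = build_psnr_metrics(with_y_true=True,
                                 addition_imgs={"clean": clean_img},
                                 data_format=None)
    model.compile(optimizer="adam", loss="mse", metrics=metrics)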
Example #3
 def conv_sp_var_drop():
     input2 = tf.multiply(inputs,inputs)
     # Clip value as suggested by the author's implementation
     if train_prune:
         kern = tf.where(kernel_prune_mask, self.kernel, kernel_clip_mask)
     else:
         kern = self.kernel  # assumed fallback: use the full kernel when not pruning
     # Convolution with reparameterization trick
     theta = nn.convolution(input=inputs,
                  filter=kern,
                  dilation_rate=self.dilation_rate,
                  strides=self.strides,
                  padding=self.padding.upper(),
                  data_format=convert_data_format(self.data_format, self.rank+2))
     sigma = tf.sqrt(nn.convolution(input=input2,
                          filter=tf.exp(log_alpha)*kern*kern,
                          dilation_rate=self.dilation_rate,
                          strides=self.strides,
                          padding=self.padding.upper(),
                          data_format=convert_data_format(self.data_format, self.rank+2)))
     noise = tf.random_normal(shape=tf.shape(theta))  # dynamic shape handles a None batch dim
     outputs = theta + noise * sigma
     # bias
     if self.bias is not None:
         if self.data_format == "channels_first":
             if self.rank == 1:
                 bias = tf.reshape(self.bias, (1, self.filters, 1))
                 outputs += bias
             if self.rank == 2:
                 outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
             if self.rank == 3:
                 outputs_shape = outputs.shape.as_list()
                 outputs_4d = tf.reshape(outputs,
                                         [outputs_shape[0], outputs_shape[1],
                                          outputs_shape[2] * outputs_shape[3],
                                           outputs_shape[4]])
                 outputs_4d = tf.nn.bias_add(outputs_4d, self.bias, data_format="NCHW")
                 outputs = tf.reshape(outputs_4d, outputs_shape)
         else:
             outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
     # Activation
     if self.activation is not None:
         return self.activation(outputs)
     return outputs
Example #4
 def conv_var_drop():
     input2 = tf.multiply(inputs,inputs)
     alpha = tf.clip_by_value(self.alpha, 0, 1)
     # Convolution with reparameterization trick
     theta = nn.convolution(input=inputs,
                  filter=self.kernel,
                  dilation_rate=self.dilation_rate,
                  strides=self.strides,
                  padding=self.padding.upper(),
                  data_format=convert_data_format(self.data_format, self.rank+2))
     sigma = tf.sqrt(eps+nn.convolution(input=input2,
                          filter=alpha*self.kernel*self.kernel,
                          dilation_rate=self.dilation_rate,
                          strides=self.strides,
                          padding=self.padding.upper(),
                          data_format=convert_data_format(self.data_format, self.rank+2)))
     noise = tf.random_normal(shape=tf.shape(theta))  # dynamic shape handles a None batch dim
     outputs = theta + noise * sigma
     # bias
     if self.bias is not None:
         if self.data_format == "channels_first":
             if self.rank == 1:
                 bias = tf.reshape(self.bias, (1, self.filters, 1))
                 outputs += bias
             if self.rank == 2:
                 outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
             if self.rank == 3:
                 outputs_shape = outputs.shape.as_list()
                 outputs_4d = tf.reshape(outputs,
                                         [outputs_shape[0], outputs_shape[1],
                                          outputs_shape[2] * outputs_shape[3],
                                           outputs_shape[4]])
                 outputs_4d = tf.nn.bias_add(outputs_4d, self.bias, data_format="NCHW")
                 outputs = tf.reshape(outputs_4d, outputs_shape)
         else:
             outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
     # Activation
     if self.activation is not None:
         return self.activation(outputs)
     return outputs
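
The two stochastic functions above use the local reparameterization trick:
instead of sampling a noisy kernel, they compute the mean `theta` and standard
deviation `sigma` of the pre-activation distribution and sample the output
directly. The same idea for a dense layer, as a minimal sketch (`x`, `w`, and
`alpha` are illustrative placeholders):

    # out ~ N(x @ w, (x*x) @ (alpha * w * w)); eps guards the sqrt as above.
    eps = 1e-8
    theta = tf.matmul(x, w)                                  # output mean
    sigma = tf.sqrt(eps + tf.matmul(x * x, alpha * w * w))   # output std dev
    out = theta + sigma * tf.random_normal(tf.shape(theta))  # sampled output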
Example #5
 def __init__(
     self,
     filters,
     kernel_size,
     strides=1,
     padding_mode="zero",
     data_format="NHWC",
     use_bias=True,
     **kwargs,
 ):
     super().__init__(trainable=True, **kwargs)
     self.filters = int(filters)
     self.kernel_size = int(kernel_size)
     self.strides = strides
     self.padding_mode = str(padding_mode.lower())
     if self.padding_mode == "zero":
         self.padding_mode = "constant"
     self.data_format = str(utils.convert_data_format(data_format))
     self.use_bias = bool(use_bias)
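
Judging by Example #6, this constructor appears to belong to the
`ConvWithPad2D` layer; a construction sketch under that assumption, with
illustrative argument values:

    conv = layers.ConvWithPad2D(filters=64,
                                kernel_size=3,
                                strides=1,
                                padding_mode="reflect",  # "zero" is stored as "constant"
                                data_format="NHWC",
                                use_bias=True)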
Example #6
def build_skip_net(
    input_channels=2,
    output_channels=3,
    levels=5,
    down_channels=128,
    up_channels=128,
    skip_channels=4,
    down_sizes=3,
    up_sizes=3,
    skip_sizes=1,
    downsample_modes="stride",
    upsample_modes="nearest",
    padding_mode="zero",
    use_sigmoid=True,
    use_bias=True,
    use_1x1up=True,
    data_format=None,
    activations=tf.keras.layers.LeakyReLU,
):
    """Build an autoencoder network with skip connections.

    Arguments:
        input_channels: Integer
        output_channels: Integer
        levels: Integer. Number of encoder and decoder pairs
        down_channels: An integer or tuple/list of integers. How many channels
            in each encoder.
        up_channels: An integer or tuple/list of integers. How many channels
            in each decoder.
        skip_channels: An integer or tuple/list of integers. How many channels
            in each skip connection.
        down_sizes: An integer or tuple/list of integers. `kernel_size` of
            each encoder.
        up_sizes: An integer or tuple/list of integers. `kernel_size` of
            each decoder.
        skip_sizes: An integer or tuple/list of integers. `kernel_size` of
            each skip connection.
        downsample_modes: A string or tuple/list of strings. One of `"stride"`.
        upsample_modes: A string or tuple/list of strings. One of `"nearest"`
            or `"bilinear"`.
        padding_mode: One of `"constant"`, `"reflect"`, `"symmetric"`,
            or `"zero"` (case-insensitive).
        use_sigmoid: Boolean
        use_bias: Boolean, whether the layer uses a bias vector.
        use_1x1up: Boolean
        data_format: A string, one of `"NHWC"` or `"NCHW"`, or `None` to use
            the default format.
        activations: Activation function to use.

    Returns:
        A `tf.keras.Model`.

    """
    data_format = utils.convert_data_format(data_format)
    keras_data_format = utils.convert_data_format(data_format, False)
    channels_axis = -1
    input_shape = (None, None, input_channels)
    if data_format == "NCHW":
        channels_axis = 1
        input_shape = (input_channels, None, None)

    down_channels = tf.constant(down_channels, shape=levels)
    up_channels = tf.constant(up_channels, shape=levels)
    skip_channels = tf.constant(skip_channels, shape=levels)

    down_sizes = tf.constant(down_sizes, shape=levels)
    up_sizes = tf.constant(up_sizes, shape=levels)
    skip_sizes = tf.constant(skip_sizes, shape=levels)

    downsample_modes = (tf.constant(downsample_modes,
                                    shape=levels,
                                    dtype=tf.string).numpy().astype(str))
    upsample_modes = (tf.constant(upsample_modes,
                                  shape=levels,
                                  dtype=tf.string).numpy().astype(str))

    inputs = tf.keras.layers.Input(shape=input_shape)

    # First, we add layers along the deeper branch.
    deeper_startnodes = [None] * (levels + 1)
    deeper_startnodes[0] = inputs

    for i in range(levels):
        with tf.name_scope(f"deeper_{i}"):
            output = layers.ConvWithPad2D(
                down_channels[i],
                down_sizes[i],
                strides=2,
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(deeper_startnodes[i])
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            output = activations()(output)

            output = layers.ConvWithPad2D(
                down_channels[i],
                down_sizes[i],
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(output)
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            deeper_startnodes[i + 1] = activations()(output)

    # Second, we add skip connections (if any) to deeper main nodes.
    skip_nodes = [None] * levels

    for i in range(levels):
        with tf.name_scope(f"skip_{i}"):
            if skip_channels[i] != 0:
                output = layers.ConvWithPad2D(
                    skip_channels[i],
                    skip_sizes[i],
                    padding_mode=padding_mode,
                    use_bias=use_bias,
                    data_format=data_format,
                )(deeper_startnodes[i])
                output = tf.keras.layers.BatchNormalization(
                    axis=channels_axis)(output)
                skip_nodes[i] = activations()(output)

    # Finally, we concatenate each skip connection with the deeper branch
    # (if the skip exists) or pass the deeper branch through unchanged.
    # The final layers of each level must be connected to the ending node
    # of the sublayers, so this loop builds `deeper_endnodes` from the
    # deepest layer first.

    deeper_endnodes = [None] * (levels + 1)
    deeper_endnodes[-1] = deeper_startnodes[-1]

    # Reversed loop because we build the deepest layer first.
    for i in range(levels - 1, -1, -1):
        with tf.name_scope(f"deeper_{i}"):
            # Upsampling before concatenation.
            deeper_endnodes[i + 1] = tf.keras.layers.UpSampling2D(
                interpolation=upsample_modes[i],
                data_format=keras_data_format)(deeper_endnodes[i + 1])

            output = (tf.keras.layers.Concatenate(
                axis=channels_axis)([skip_nodes[i], deeper_endnodes[i + 1]])
                      if skip_channels[i] != 0 else deeper_endnodes[i + 1])

            # Some final layers for deeper.
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)

            output = layers.ConvWithPad2D(
                up_channels[i],
                up_sizes[i],
                padding_mode=padding_mode,
                use_bias=use_bias,
                data_format=data_format,
            )(output)
            output = tf.keras.layers.BatchNormalization(
                axis=channels_axis)(output)
            output = activations()(output)

            if use_1x1up:
                output = layers.ConvWithPad2D(
                    up_channels[i],
                    1,
                    padding_mode=padding_mode,
                    use_bias=use_bias,
                    data_format=data_format,
                )(output)
                output = tf.keras.layers.BatchNormalization(
                    axis=channels_axis)(output)
                output = activations()(output)

            deeper_endnodes[i] = output

    # Final touches
    outputs = layers.ConvWithPad2D(
        output_channels,
        1,
        padding_mode=padding_mode,
        use_bias=use_bias,
        data_format=data_format,
    )(deeper_endnodes[0])
    if use_sigmoid:
        outputs = tf.keras.layers.Activation(tf.nn.sigmoid)(outputs)

    return tf.keras.Model(inputs=inputs, outputs=outputs)
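
A construction sketch for the builder above, with illustrative argument
values:

    # A 5-level skip autoencoder: 2 input channels, RGB output.
    model = build_skip_net(input_channels=2,
                           output_channels=3,
                           levels=5,
                           skip_channels=4,
                           upsample_modes="bilinear")
    model.summary()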
Example #7
 def psnr(y_true, y_pred):
     if utils.convert_data_format(data_format) == "NCHW":
         y_true = tf.transpose(y_true, [0, 2, 3, 1])
         y_pred = tf.transpose(y_pred, [0, 2, 3, 1])
     return tf.image.psnr(y_true, y_pred, max_val=1.0)