def _addSecondaryNetBlock(self,
                          input,
                          inputMean,
                          lastGlobalNetworkValue,
                          currentChannels,
                          nextChannels,
                          layerCount,
                          keep_dims=False):
        """Run one step of the global-feature secondary network beside a U-Net layer.

        Args:
            input: activations of the current U-Net layer ([batch, H, W, currentChannels]).
            inputMean: precomputed spatial mean of `input`, or None to compute it here.
            lastGlobalNetworkValue: previous secondary-network output, or None on the
                first step.
            currentChannels: channel count of `input`, used to project the previous
                global value before adding it.
            nextChannels: output width of this step's fully connected layer.
            layerCount: zero-based index, used only for variable-scope/layer naming.
            keep_dims: forwarded to tf.nn.moments when the mean is computed here.

        Returns:
            (summed, nextGlobalValue): `input` plus the projected previous global
            value (or `input` unchanged on the first step), and the selu-activated
            output of this step's fully connected layer. Returns (input, None) when
            the secondary network is disabled.
        """
        # When the secondary network is disabled this is a no-op pass-through.
        if not self.useSecondary:
            return input, None

        if inputMean is None:
            inputMean, _ = tf.nn.moments(input, axes=[1, 2], keep_dims=keep_dims)

        summed = input
        if lastGlobalNetworkValue is not None:
            # Inject the previous global value into the U-Net activations.
            summed = input + tfHelpers.GlobalToGenerator(lastGlobalNetworkValue,
                                                         currentChannels)

        with tf.variable_scope("globalNetwork_fc_%d" % (layerCount + 1)):
            nextGlobalInput = inputMean
            if lastGlobalNetworkValue is not None:
                # Concatenate the previous global value (expanded to match the
                # [batch, 1, 1, C] mean) with this layer's mean along channels.
                nextGlobalInput = tf.concat([
                    tf.expand_dims(tf.expand_dims(lastGlobalNetworkValue, axis=1),
                                   axis=1), inputMean
                ],
                                            axis=-1)
            globalNetwork_fc = tfHelpers.fullyConnected(
                nextGlobalInput, nextChannels, True,
                "globalNetworkLayer" + str(layerCount + 1))

        # selu keeps the fully connected chain self-normalizing.
        return summed, tf.nn.selu(
            globalNetwork_fc
        )  #returns the sum of this layer + last globalNet output and a new globalNetValue
# Beispiel #2
# 0
    def create_generator(self,
                         generator_inputs,
                         generator_outputs_channels,
                         materialEncoded,
                         reuse_bool=True):
        """Build a U-Net generator with a parallel global-feature network.

        The encoder/decoder convolutions run alongside a chain of fully
        connected layers (the "global network") fed by per-layer spatial
        means; each FC output is projected back and added into the matching
        U-Net layer.

        Args:
            generator_inputs: input image tensor, expected [batch, 256, 256, 3]
                per the comment below — TODO confirm against callers.
            generator_outputs_channels: channel count of the generated output.
            materialEncoded: NOTE(review) not used anywhere in this body —
                confirm whether it should feed the network or can be removed.
            reuse_bool: forwarded to tf.variable_scope to share generator weights.

        Returns:
            (output, lastGlobalNet): the tanh-activated generated image and the
            final global-network value projected to the output channel count.
        """
        with tf.variable_scope("generator", reuse=reuse_bool) as scope:
            layers = []
            #Input here should be [batch, 256,256,3]
            # Spatial mean over H and W seeds the global network ([batch, C]).
            inputMean, inputVariance = tf.nn.moments(generator_inputs,
                                                     axes=[1, 2],
                                                     keep_dims=False)
            globalNetworkInput = inputMean
            globalNetworkOutputs = []

            # First global-network step; selu keeps the FC chain self-normalizing.
            with tf.variable_scope("globalNetwork_fc_1"):
                globalNetwork_fc_1 = tfHelpers.fullyConnected(
                    globalNetworkInput, self.ngf * 2, True,
                    "globalNetworkLayer" + str(len(globalNetworkOutputs) + 1))
                globalNetworkOutputs.append(tf.nn.selu(globalNetwork_fc_1))

            #encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            with tf.variable_scope("encoder_1"):
                output = tfHelpers.conv(generator_inputs, self.ngf, stride=2)
                layers.append(output)
            #Default ngf is 64
            layer_specs = [
                self.ngf *
                2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
                self.ngf *
                4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
                self.ngf *
                8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
                self.ngf *
                8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
                self.ngf *
                8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
                self.ngf *
                8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
                #self.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
            ]

            for layerCount, out_channels in enumerate(layer_specs):
                with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                    rectified = tfHelpers.lrelu(layers[-1], 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                    convolved = tfHelpers.conv(rectified,
                                               out_channels,
                                               stride=2)
                    #here mean and variance will be [batch, 1, 1, out_channels]
                    outputs, mean, variance = tfHelpers.instancenorm(convolved)

                    # Inject the latest global-network value into this layer.
                    outputs = outputs + self.GlobalToGenerator(
                        globalNetworkOutputs[-1], out_channels)
                    with tf.variable_scope("globalNetwork_fc_%d" %
                                           (len(globalNetworkOutputs) + 1)):
                        # Next global input: previous global value (expanded to
                        # [batch, 1, 1, C]) concatenated with this layer's mean.
                        nextGlobalInput = tf.concat([
                            tf.expand_dims(tf.expand_dims(
                                globalNetworkOutputs[-1], axis=1),
                                           axis=1), mean
                        ],
                                                    axis=-1)
                        # Placeholder; always overwritten by a branch below.
                        globalNetwork_fc = ""
                        # FC width follows the NEXT encoder's channel count,
                        # except near the end where the last width is repeated.
                        if layerCount + 1 < len(layer_specs) - 1:
                            globalNetwork_fc = tfHelpers.fullyConnected(
                                nextGlobalInput, layer_specs[layerCount + 1],
                                True, "globalNetworkLayer" +
                                str(len(globalNetworkOutputs) + 1))
                        else:
                            globalNetwork_fc = tfHelpers.fullyConnected(
                                nextGlobalInput, layer_specs[layerCount], True,
                                "globalNetworkLayer" +
                                str(len(globalNetworkOutputs) + 1))

                        globalNetworkOutputs.append(
                            tf.nn.selu(globalNetwork_fc))
                    layers.append(outputs)

            with tf.variable_scope("encoder_8"):
                # Last encoder layer: no instance normalization here.
                rectified = tfHelpers.lrelu(layers[-1], 0.2)
                # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                convolved = tfHelpers.conv(rectified, self.ngf * 8, stride=2)
                convolved = convolved + self.GlobalToGenerator(
                    globalNetworkOutputs[-1], self.ngf * 8)

                with tf.variable_scope("globalNetwork_fc_%d" %
                                       (len(globalNetworkOutputs) + 1)):
                    # NOTE(review): moments are taken AFTER the global value was
                    # added to `convolved`; the sibling variant of this method
                    # uses the pre-addition activations — confirm intended.
                    mean, variance = tf.nn.moments(convolved,
                                                   axes=[1, 2],
                                                   keep_dims=True)
                    nextGlobalInput = tf.concat([
                        tf.expand_dims(tf.expand_dims(globalNetworkOutputs[-1],
                                                      axis=1),
                                       axis=1), mean
                    ],
                                                axis=-1)
                    globalNetwork_fc = tfHelpers.fullyConnected(
                        nextGlobalInput, self.ngf * 8, True,
                        "globalNetworkLayer" +
                        str(len(globalNetworkOutputs) + 1))
                    globalNetworkOutputs.append(tf.nn.selu(globalNetwork_fc))

                layers.append(convolved)
            #default nfg = 64
            layer_specs = [
                (
                    self.ngf * 8, 0.5
                ),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.5
                ),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.5
                ),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.0
                ),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
                (
                    self.ngf * 4, 0.0
                ),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
                (
                    self.ngf * 2, 0.0
                ),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
                (
                    self.ngf, 0.0
                ),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
            ]

            num_encoder_layers = len(layers)
            for decoder_layer, (out_channels,
                                dropout) in enumerate(layer_specs):
                # Mirror index of the encoder layer to concatenate (skip link).
                skip_layer = num_encoder_layers - decoder_layer - 1
                with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
                    if decoder_layer == 0:
                        # first decoder layer doesn't have skip connections
                        # since it is directly connected to the skip_layer
                        input = layers[-1]
                    else:
                        input = tf.concat([layers[-1], layers[skip_layer]],
                                          axis=3)

                    rectified = tfHelpers.lrelu(input, 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
                    output = tfHelpers.deconv(rectified, out_channels)
                    output, mean, variance = tfHelpers.instancenorm(output)
                    output = output + self.GlobalToGenerator(
                        globalNetworkOutputs[-1], out_channels)
                    with tf.variable_scope("globalNetwork_fc_%d" %
                                           (len(globalNetworkOutputs) + 1)):
                        nextGlobalInput = tf.concat([
                            tf.expand_dims(tf.expand_dims(
                                globalNetworkOutputs[-1], axis=1),
                                           axis=1), mean
                        ],
                                                    axis=-1)
                        globalNetwork_fc = tfHelpers.fullyConnected(
                            nextGlobalInput, out_channels, True,
                            "globalNetworkLayer" +
                            str(len(globalNetworkOutputs) + 1))
                        globalNetworkOutputs.append(
                            tf.nn.selu(globalNetwork_fc))
                    if dropout > 0.0:
                        # `rate=` drops a `dropout` fraction of units
                        # (tf.nn.dropout rate semantics, TF >= 1.13).
                        output = tf.nn.dropout(output, rate=dropout)

                    layers.append(output)

            # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
            with tf.variable_scope("decoder_1"):
                input = tf.concat([layers[-1], layers[0]], axis=3)
                rectified = tfHelpers.lrelu(input, 0.2)
                output = tfHelpers.deconv(rectified,
                                          generator_outputs_channels)
                # Final global-network projection; also returned to the caller.
                lastGlobalNet = self.GlobalToGenerator(
                    globalNetworkOutputs[-1], generator_outputs_channels)
                output = output + lastGlobalNet
                output = tf.tanh(output)
                layers.append(output)

            return layers[-1], lastGlobalNet
    def create_generator(self,
                         generator_inputs,
                         generator_outputs_channels,
                         reuse_bool=True):
        """Build a U-Net generator with a parallel global-feature network.

        Variant without the `materialEncoded` argument. NOTE(review): if this
        and the earlier `create_generator` live in the same class, this later
        definition shadows the earlier one — confirm which is meant to be used.

        Args:
            generator_inputs: input image tensor, expected [batch, 256, 256, 3]
                per the comment below — TODO confirm against callers.
            generator_outputs_channels: channel count of the generated output.
            reuse_bool: forwarded to tf.variable_scope to share generator weights.

        Returns:
            (output, lastGlobalNet): the generated image (linear output — the
            final tanh is commented out below) and the last global-network
            value projected to the output channel count.
        """
        with tf.variable_scope("generator", reuse=reuse_bool) as scope:
            #Print the shape to check we are inputting a tensor with a reasonable shape
            print("generator_inputs :" + str(generator_inputs.get_shape()))
            print("generator_outputs_channels :" +
                  str(generator_outputs_channels))
            layers = []
            #Input here should be [batch, 256,256,3]
            # Spatial mean over H and W seeds the global network ([batch, C]).
            inputMean, inputVariance = tf.nn.moments(generator_inputs,
                                                     axes=[1, 2],
                                                     keep_dims=False)
            globalNetworkInput = inputMean
            globalNetworkOutputs = []
            with tf.variable_scope("globalNetwork_fc_1"):
                globalNetwork_fc_1 = tfHelpers.fullyConnected(
                    globalNetworkInput, self.ngf * 2, True,
                    "globalNetworkLayer" + str(len(globalNetworkOutputs) + 1))
                globalNetworkOutputs.append(tf.nn.selu(globalNetwork_fc_1))

            #encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            with tf.variable_scope("encoder_1"):
                #Convolution with stride 2 and kernel size 4x4.
                output = tfHelpers.conv(generator_inputs, self.ngf, stride=2)
                layers.append(output)
            #Default ngf is 64
            layer_specs = [
                self.ngf *
                2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
                self.ngf *
                4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
                self.ngf *
                8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
                self.ngf *
                8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
                self.ngf *
                8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
                self.ngf *
                8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
                #self.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
            ]

            for layerCount, out_channels in enumerate(layer_specs):
                with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                    #We use a leaky relu instead of a relu to let a bit more expressivity to the network.
                    rectified = tfHelpers.lrelu(layers[-1], 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                    convolved = tfHelpers.conv(rectified,
                                               out_channels,
                                               stride=2)
                    #here mean and variance will be [batch, 1, 1, out_channels] and we run an instance normalization
                    outputs, mean, variance = tfHelpers.instancenorm(convolved)

                    #Get the last value in the global feature secondary network and transform it to be added to the current Unet layer output.
                    outputs = outputs + tfHelpers.GlobalToGenerator(
                        globalNetworkOutputs[-1], out_channels)
                    with tf.variable_scope("globalNetwork_fc_%d" %
                                           (len(globalNetworkOutputs) + 1)):
                        #Prepare the input to the next global feature secondary network step and run it.
                        nextGlobalInput = tf.concat([
                            tf.expand_dims(tf.expand_dims(
                                globalNetworkOutputs[-1], axis=1),
                                           axis=1), mean
                        ],
                                                    axis=-1)
                        # Placeholder; always overwritten by a branch below.
                        globalNetwork_fc = ""
                        # FC width follows the NEXT encoder's channel count,
                        # except near the end where the last width is repeated.
                        if layerCount + 1 < len(layer_specs) - 1:
                            globalNetwork_fc = tfHelpers.fullyConnected(
                                nextGlobalInput, layer_specs[layerCount + 1],
                                True, "globalNetworkLayer" +
                                str(len(globalNetworkOutputs) + 1))
                        else:
                            globalNetwork_fc = tfHelpers.fullyConnected(
                                nextGlobalInput, layer_specs[layerCount], True,
                                "globalNetworkLayer" +
                                str(len(globalNetworkOutputs) + 1))
                        #We use selu as we are in a fully connected network and it has auto normalization properties.
                        globalNetworkOutputs.append(
                            tf.nn.selu(globalNetwork_fc))
                    layers.append(outputs)

            with tf.variable_scope("encoder_8"):
                #The last encoder is mostly similar to previous layers except that we don't normalize the output.
                rectified = tfHelpers.lrelu(layers[-1], 0.2)
                # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
                convolvedNoGlobal = tfHelpers.conv(rectified,
                                                   self.ngf * 8,
                                                   stride=2)
                convolved = convolvedNoGlobal + tfHelpers.GlobalToGenerator(
                    globalNetworkOutputs[-1], self.ngf * 8)

                with tf.variable_scope("globalNetwork_fc_%d" %
                                       (len(globalNetworkOutputs) + 1)):
                    # Moments come from the PRE-addition activations here,
                    # unlike the sibling variant which uses the summed tensor.
                    mean, variance = tf.nn.moments(convolvedNoGlobal,
                                                   axes=[1, 2],
                                                   keep_dims=True)
                    nextGlobalInput = tf.concat([
                        tf.expand_dims(tf.expand_dims(globalNetworkOutputs[-1],
                                                      axis=1),
                                       axis=1), mean
                    ],
                                                axis=-1)
                    globalNetwork_fc = tfHelpers.fullyConnected(
                        nextGlobalInput, self.ngf * 8, True,
                        "globalNetworkLayer" +
                        str(len(globalNetworkOutputs) + 1))
                    globalNetworkOutputs.append(tf.nn.selu(globalNetwork_fc))

                layers.append(convolved)
            #default nfg = 64
            layer_specs = [
                (
                    self.ngf * 8, 0.5
                ),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.5
                ),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.5
                ),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
                (
                    self.ngf * 8, 0.0
                ),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
                (
                    self.ngf * 4, 0.0
                ),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
                (
                    self.ngf * 2, 0.0
                ),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
                (
                    self.ngf, 0.0
                ),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
            ]
            #Start the decoder here
            num_encoder_layers = len(layers)
            for decoder_layer, (out_channels,
                                dropout) in enumerate(layer_specs):
                skip_layer = num_encoder_layers - decoder_layer - 1
                #Evaluate which layer from the encoder has to be contatenated for the skip connection
                with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
                    if decoder_layer == 0:
                        # first decoder layer doesn't have skip connections
                        # since it is directly connected to the skip_layer
                        input = layers[-1]
                    else:
                        input = tf.concat([layers[-1], layers[skip_layer]],
                                          axis=3)

                    #Leaky relu some more (same reason as in the encoder)
                    rectified = tfHelpers.lrelu(input, 0.2)
                    # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]

                    #The deconvolution has stride 1 and shape 4x4. Theorically, it should be shape 3x3 to avoid any effects on the image borders, but it doesn't seem to have such a strong effect.
                    output = tfHelpers.deconv(rectified, out_channels)

                    #Instance norm and global feature secondary network similar to the decoder.
                    output, mean, variance = tfHelpers.instancenorm(output)
                    output = output + tfHelpers.GlobalToGenerator(
                        globalNetworkOutputs[-1], out_channels)
                    with tf.variable_scope("globalNetwork_fc_%d" %
                                           (len(globalNetworkOutputs) + 1)):
                        nextGlobalInput = tf.concat([
                            tf.expand_dims(tf.expand_dims(
                                globalNetworkOutputs[-1], axis=1),
                                           axis=1), mean
                        ],
                                                    axis=-1)
                        globalNetwork_fc = tfHelpers.fullyConnected(
                            nextGlobalInput, out_channels, True,
                            "globalNetworkLayer" +
                            str(len(globalNetworkOutputs) + 1))
                        globalNetworkOutputs.append(
                            tf.nn.selu(globalNetwork_fc))
                    if dropout > 0.0:
                        #We use dropout as described in the pix2pix paper.
                        # keep_prob = 1 - dropout: equivalent to rate=dropout
                        # in the newer tf.nn.dropout API.
                        output = tf.nn.dropout(output, keep_prob=1 - dropout)

                    layers.append(output)

            # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
            with tf.variable_scope("decoder_1"):
                input = tf.concat([layers[-1], layers[0]], axis=3)
                rectified = tfHelpers.lrelu(input, 0.2)
                output = tfHelpers.deconv(rectified,
                                          generator_outputs_channels)
                lastGlobalNet = tfHelpers.GlobalToGenerator(
                    globalNetworkOutputs[-1], generator_outputs_channels)
                output = output + lastGlobalNet
                # Final tanh deliberately disabled: output range is unbounded.
                #output = tf.tanh(output)
                layers.append(output)

            return layers[-1], lastGlobalNet
# Beispiel #4
# 0
 def GlobalToGenerator(self, inputs, channels):
     """Project a global-network vector for injection into a U-Net layer.

     Maps `inputs` through a non-activated fully connected layer (init scale
     0.01) to `channels` values, then reshapes to [batch, 1, 1, channels] so
     the result broadcasts over the spatial dimensions when added.
     """
     with tf.variable_scope("GlobalToGenerator1"):
         projected = tfHelpers.fullyConnected(inputs, channels, False,
                                              "fc_global_to_unet", 0.01)
     # Insert the two unit spatial axes one at a time.
     expanded = tf.expand_dims(projected, axis=1)
     return tf.expand_dims(expanded, axis=1)