def __create_decoder(self, encoder_results, lastGlobalNet, output_channels):
    """Build the decoder stack on top of the encoder feature maps.

    Seven intermediate decoder layers are created from a fixed spec table,
    followed by a final ``decoder_1`` layer. Every layer after the first
    concatenates the previous decoder output with an encoder output along
    axis 3, and every layer is routed through ``self._addSecondaryNetBlock``.

    Args:
        encoder_results: Sequence of encoder outputs; index 0 is the first
            encoder layer and index -1 the deepest one.
        lastGlobalNet: Current secondary-network state, threaded through
            every ``_addSecondaryNetBlock`` call.
        output_channels: Channel count requested from the final deconv.

    Returns:
        Tuple ``(final_output, lastGlobalNet)`` where ``final_output`` is the
        result of the ``decoder_1`` block and ``lastGlobalNet`` is the value
        returned by the last ``_addSecondaryNetBlock`` call.
    """
    # (output channels, dropout probability) per decoder layer.
    # Shape notes are the original author's and assume 8 encoder layers.
    decoder_specs = [
        (self.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8]
        (self.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8] => [batch, 4, 4, ngf * 8]
        (self.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8] => [batch, 8, 8, ngf * 8]
        # Dropout is 0.5 up to here and disabled below.
        (self.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8] => [batch, 16, 16, ngf * 8]
        (self.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8] => [batch, 32, 32, ngf * 4]
        (self.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4] => [batch, 64, 64, ngf * 2]
        (self.ngf, 0.0),  # decoder_2: [batch, 64, 64, ngf * 2] => [batch, 128, 128, ngf]
    ]

    decoded = []
    encoder_depth = len(encoder_results)

    for step, (channels, drop_prob) in enumerate(decoder_specs):
        # Index of the encoder layer that feeds this step's skip connection.
        skip_index = encoder_depth - step - 1
        scope_name = "decoder_%d" % (skip_index + 1)
        with tf.variable_scope(scope_name):
            if step > 0:
                layer_input = tf.concat(
                    [decoded[-1], encoder_results[skip_index]], axis=3)
            else:
                # The first decoder layer consumes the deepest encoder
                # output directly, so there is nothing to concatenate.
                layer_input = encoder_results[-1]

            activated = tfHelpers.lrelu(layer_input, 0.2)
            # [batch, h, w, in_channels] => [batch, h*2, w*2, channels]
            upsampled = tfHelpers.deconv(activated, channels)
            normalized, mean, _ = tfHelpers.instancenorm(upsampled)

            block_index = encoder_depth + len(decoded)
            layer_output, lastGlobalNet = self._addSecondaryNetBlock(
                normalized, mean, lastGlobalNet, channels, channels,
                block_index)

            if drop_prob > 0.0:
                layer_output = tf.nn.dropout(
                    layer_output, keep_prob=1 - drop_prob)

            decoded.append(layer_output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, output_channels]
    with tf.variable_scope("decoder_1"):
        final_input = tf.concat([decoded[-1], encoder_results[0]], axis=3)
        final_activated = tfHelpers.lrelu(final_input, 0.2)
        # Open question from the original author: should this be normalized?
        final_deconv = tfHelpers.deconv(final_activated, output_channels)
        final_output, lastGlobalNet = self._addSecondaryNetBlock(
            final_deconv, None, lastGlobalNet, output_channels,
            output_channels, encoder_depth + len(decoded), True)
        # No tanh is applied to the final output here.

    return final_output, lastGlobalNet
def create_generator(self, generator_inputs, generator_outputs_channels, reuse_bool=True):
    """Build the U-Net style generator together with its global feature track.

    The graph is created inside the ``"generator"`` variable scope. Alongside
    the encoder/decoder layers, a chain of fully connected "global network"
    layers is maintained: each conv/deconv layer receives the latest global
    output through ``tfHelpers.GlobalToGenerator`` and then feeds its
    per-channel mean back into the next fully connected layer.

    Args:
        generator_inputs: Input tensor (the original comment expects
            [batch, 256, 256, 3]).
        generator_outputs_channels: Channel count requested from the final
            deconvolution.
        reuse_bool: Forwarded as ``reuse`` to the ``"generator"`` scope.

    Returns:
        Tuple ``(output, lastGlobalNet)``: the final decoder output (no tanh
        applied) and the global contribution that was added to it.
    """
    # NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    # source. The "globalNetwork_fc_%d" scopes are assumed to be nested inside
    # their encoder/decoder scope; confirm against existing checkpoints.
    with tf.variable_scope("generator", reuse=reuse_bool):
        # Print the shapes to check the input tensor looks reasonable.
        print("generator_inputs :" + str(generator_inputs.get_shape()))
        print("generator_outputs_channels :" + str(generator_outputs_channels))

        layers = []
        global_outputs = []

        # Seed the global track with the per-channel mean of the input.
        input_mean, _ = tf.nn.moments(
            generator_inputs, axes=[1, 2], keep_dims=False)
        with tf.variable_scope("globalNetwork_fc_1"):
            first_fc = tfHelpers.fullyConnected(
                input_mean, self.ngf * 2, True,
                "globalNetworkLayer" + str(len(global_outputs) + 1))
            global_outputs.append(tf.nn.selu(first_fc))

        # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
        with tf.variable_scope("encoder_1"):
            first_conv = tfHelpers.conv(generator_inputs, self.ngf, stride=2)
            layers.append(first_conv)

        # Output channels of encoder_2 .. encoder_7 (default ngf is 64).
        encoder_specs = [
            self.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
            self.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
            self.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
            self.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
            self.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
            self.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        ]

        for layer_index, channels in enumerate(encoder_specs):
            with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                # Leaky ReLU rather than ReLU, for a bit more expressivity.
                activated = tfHelpers.lrelu(layers[-1], 0.2)
                # [batch, h, w, in_channels] => [batch, h/2, w/2, channels]
                convolved = tfHelpers.conv(activated, channels, stride=2)
                # Per the original comment, mean/variance are
                # [batch, 1, 1, channels].
                normalized, mean, _ = tfHelpers.instancenorm(convolved)

                # Inject the latest global feature into this layer's output.
                global_term = tfHelpers.GlobalToGenerator(
                    global_outputs[-1], channels)
                encoded = normalized + global_term

                with tf.variable_scope(
                        "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                    # Next global input: previous global output + layer mean.
                    expanded = tf.expand_dims(global_outputs[-1], axis=1)
                    expanded = tf.expand_dims(expanded, axis=1)
                    fc_input = tf.concat([expanded, mean], axis=-1)

                    # Size the FC layer after the next encoder layer, except
                    # for the last two entries which use their own size.
                    if layer_index + 1 < len(encoder_specs) - 1:
                        fc_units = encoder_specs[layer_index + 1]
                    else:
                        fc_units = encoder_specs[layer_index]
                    global_fc = tfHelpers.fullyConnected(
                        fc_input, fc_units, True,
                        "globalNetworkLayer" + str(len(global_outputs) + 1))
                    # SELU is used for the fully connected track.
                    global_outputs.append(tf.nn.selu(global_fc))

                layers.append(encoded)

        with tf.variable_scope("encoder_8"):
            # Same as the layers above, but without instance normalization.
            activated = tfHelpers.lrelu(layers[-1], 0.2)
            # [batch, h, w, in_channels] => [batch, h/2, w/2, ngf * 8]
            conv_no_global = tfHelpers.conv(activated, self.ngf * 8, stride=2)
            global_term = tfHelpers.GlobalToGenerator(
                global_outputs[-1], self.ngf * 8)
            bottleneck = conv_no_global + global_term

            with tf.variable_scope(
                    "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                # Statistics are taken before the global term is added.
                mean, _ = tf.nn.moments(
                    conv_no_global, axes=[1, 2], keep_dims=True)
                expanded = tf.expand_dims(global_outputs[-1], axis=1)
                expanded = tf.expand_dims(expanded, axis=1)
                fc_input = tf.concat([expanded, mean], axis=-1)
                global_fc = tfHelpers.fullyConnected(
                    fc_input, self.ngf * 8, True,
                    "globalNetworkLayer" + str(len(global_outputs) + 1))
                global_outputs.append(tf.nn.selu(global_fc))

            layers.append(bottleneck)

        # (output channels, dropout probability); default ngf is 64.
        # Shape notes are the original author's (inputs include the skip).
        decoder_specs = [
            (self.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
            (self.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
            (self.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
            (self.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
            (self.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
            (self.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
            (self.ngf, 0.0),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
        ]

        # The decoder starts here; `layers` holds only encoder outputs so far.
        encoder_depth = len(layers)
        for step, (channels, drop_prob) in enumerate(decoder_specs):
            # Encoder layer concatenated in for this step's skip connection.
            skip_index = encoder_depth - step - 1
            with tf.variable_scope("decoder_%d" % (skip_index + 1)):
                if step > 0:
                    layer_input = tf.concat(
                        [layers[-1], layers[skip_index]], axis=3)
                else:
                    # First decoder layer: fed directly by the bottleneck,
                    # so no skip connection is concatenated.
                    layer_input = layers[-1]

                # Leaky ReLU again, for the same reason as in the encoder.
                activated = tfHelpers.lrelu(layer_input, 0.2)
                # [batch, h, w, in_channels] => [batch, h*2, w*2, channels]
                # Original author's note: the deconv kernel is 4x4; 3x3 would
                # in theory avoid border effects, but the impact seemed small.
                upsampled = tfHelpers.deconv(activated, channels)

                # Instance norm and global feature step, as in the encoder.
                normalized, mean, _ = tfHelpers.instancenorm(upsampled)
                global_term = tfHelpers.GlobalToGenerator(
                    global_outputs[-1], channels)
                decoded = normalized + global_term

                with tf.variable_scope(
                        "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                    expanded = tf.expand_dims(global_outputs[-1], axis=1)
                    expanded = tf.expand_dims(expanded, axis=1)
                    fc_input = tf.concat([expanded, mean], axis=-1)
                    global_fc = tfHelpers.fullyConnected(
                        fc_input, channels, True,
                        "globalNetworkLayer" + str(len(global_outputs) + 1))
                    global_outputs.append(tf.nn.selu(global_fc))

                if drop_prob > 0.0:
                    # Dropout as described in the pix2pix paper.
                    decoded = tf.nn.dropout(decoded, keep_prob=1 - drop_prob)

                layers.append(decoded)

        # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
        with tf.variable_scope("decoder_1"):
            final_input = tf.concat([layers[-1], layers[0]], axis=3)
            final_activated = tfHelpers.lrelu(final_input, 0.2)
            final_deconv = tfHelpers.deconv(
                final_activated, generator_outputs_channels)
            lastGlobalNet = tfHelpers.GlobalToGenerator(
                global_outputs[-1], generator_outputs_channels)
            final_output = final_deconv + lastGlobalNet
            # No tanh is applied to the output in this variant.
            layers.append(final_output)

        return layers[-1], lastGlobalNet
def create_generator(self, generator_inputs, generator_outputs_channels, materialEncoded, reuse_bool=True):
    """Build the U-Net style generator together with its global feature track.

    The graph is created inside the ``"generator"`` variable scope. Each
    conv/deconv layer receives the latest global-network output through
    ``self.GlobalToGenerator`` and then feeds its per-channel mean into the
    next fully connected global layer. The final output goes through tanh.

    Args:
        generator_inputs: Input tensor (the original comment expects
            [batch, 256, 256, 3]).
        generator_outputs_channels: Channel count requested from the final
            deconvolution.
        materialEncoded: Accepted for interface compatibility; it is not
            read anywhere in this function body.
        reuse_bool: Forwarded as ``reuse`` to the ``"generator"`` scope.

    Returns:
        Tuple ``(output, lastGlobalNet)``: the tanh-activated final decoder
        output and the global contribution added before the tanh.
    """
    # NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    # source. The "globalNetwork_fc_%d" scopes are assumed to be nested inside
    # their encoder/decoder scope; confirm against existing checkpoints.
    with tf.variable_scope("generator", reuse=reuse_bool):
        layers = []
        global_outputs = []

        # Seed the global track with the per-channel mean of the input.
        input_mean, _ = tf.nn.moments(
            generator_inputs, axes=[1, 2], keep_dims=False)
        with tf.variable_scope("globalNetwork_fc_1"):
            first_fc = tfHelpers.fullyConnected(
                input_mean, self.ngf * 2, True,
                "globalNetworkLayer" + str(len(global_outputs) + 1))
            global_outputs.append(tf.nn.selu(first_fc))

        # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
        with tf.variable_scope("encoder_1"):
            first_conv = tfHelpers.conv(generator_inputs, self.ngf, stride=2)
            layers.append(first_conv)

        # Output channels of encoder_2 .. encoder_7 (default ngf is 64).
        encoder_specs = [
            self.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
            self.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
            self.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
            self.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
            self.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
            self.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        ]

        for layer_index, channels in enumerate(encoder_specs):
            with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
                activated = tfHelpers.lrelu(layers[-1], 0.2)
                # [batch, h, w, in_channels] => [batch, h/2, w/2, channels]
                convolved = tfHelpers.conv(activated, channels, stride=2)
                # Per the original comment, mean/variance are
                # [batch, 1, 1, channels].
                normalized, mean, _ = tfHelpers.instancenorm(convolved)

                # Inject the latest global feature into this layer's output.
                global_term = self.GlobalToGenerator(
                    global_outputs[-1], channels)
                encoded = normalized + global_term

                with tf.variable_scope(
                        "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                    # Next global input: previous global output + layer mean.
                    expanded = tf.expand_dims(global_outputs[-1], axis=1)
                    expanded = tf.expand_dims(expanded, axis=1)
                    fc_input = tf.concat([expanded, mean], axis=-1)

                    # Size the FC layer after the next encoder layer, except
                    # for the last two entries which use their own size.
                    if layer_index + 1 < len(encoder_specs) - 1:
                        fc_units = encoder_specs[layer_index + 1]
                    else:
                        fc_units = encoder_specs[layer_index]
                    global_fc = tfHelpers.fullyConnected(
                        fc_input, fc_units, True,
                        "globalNetworkLayer" + str(len(global_outputs) + 1))
                    global_outputs.append(tf.nn.selu(global_fc))

                layers.append(encoded)

        with tf.variable_scope("encoder_8"):
            # Last encoder layer: no instance normalization is applied.
            activated = tfHelpers.lrelu(layers[-1], 0.2)
            # [batch, h, w, in_channels] => [batch, h/2, w/2, ngf * 8]
            raw_conv = tfHelpers.conv(activated, self.ngf * 8, stride=2)
            global_term = self.GlobalToGenerator(
                global_outputs[-1], self.ngf * 8)
            bottleneck = raw_conv + global_term

            with tf.variable_scope(
                    "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                # Statistics are taken after the global term has been added.
                mean, _ = tf.nn.moments(
                    bottleneck, axes=[1, 2], keep_dims=True)
                expanded = tf.expand_dims(global_outputs[-1], axis=1)
                expanded = tf.expand_dims(expanded, axis=1)
                fc_input = tf.concat([expanded, mean], axis=-1)
                global_fc = tfHelpers.fullyConnected(
                    fc_input, self.ngf * 8, True,
                    "globalNetworkLayer" + str(len(global_outputs) + 1))
                global_outputs.append(tf.nn.selu(global_fc))

            layers.append(bottleneck)

        # (output channels, dropout probability); default ngf is 64.
        # Shape notes are the original author's (inputs include the skip).
        decoder_specs = [
            (self.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
            (self.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
            (self.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
            (self.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
            (self.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
            (self.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
            (self.ngf, 0.0),  # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
        ]

        # `layers` holds only encoder outputs at this point.
        encoder_depth = len(layers)
        for step, (channels, drop_prob) in enumerate(decoder_specs):
            # Encoder layer concatenated in for this step's skip connection.
            skip_index = encoder_depth - step - 1
            with tf.variable_scope("decoder_%d" % (skip_index + 1)):
                if step > 0:
                    layer_input = tf.concat(
                        [layers[-1], layers[skip_index]], axis=3)
                else:
                    # First decoder layer: fed directly by the bottleneck,
                    # so no skip connection is concatenated.
                    layer_input = layers[-1]

                activated = tfHelpers.lrelu(layer_input, 0.2)
                # [batch, h, w, in_channels] => [batch, h*2, w*2, channels]
                upsampled = tfHelpers.deconv(activated, channels)
                normalized, mean, _ = tfHelpers.instancenorm(upsampled)
                global_term = self.GlobalToGenerator(
                    global_outputs[-1], channels)
                decoded = normalized + global_term

                with tf.variable_scope(
                        "globalNetwork_fc_%d" % (len(global_outputs) + 1)):
                    expanded = tf.expand_dims(global_outputs[-1], axis=1)
                    expanded = tf.expand_dims(expanded, axis=1)
                    fc_input = tf.concat([expanded, mean], axis=-1)
                    global_fc = tfHelpers.fullyConnected(
                        fc_input, channels, True,
                        "globalNetworkLayer" + str(len(global_outputs) + 1))
                    global_outputs.append(tf.nn.selu(global_fc))

                if drop_prob > 0.0:
                    decoded = tf.nn.dropout(decoded, rate=drop_prob)

                layers.append(decoded)

        # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
        with tf.variable_scope("decoder_1"):
            final_input = tf.concat([layers[-1], layers[0]], axis=3)
            final_activated = tfHelpers.lrelu(final_input, 0.2)
            final_deconv = tfHelpers.deconv(
                final_activated, generator_outputs_channels)
            lastGlobalNet = self.GlobalToGenerator(
                global_outputs[-1], generator_outputs_channels)
            final_output = tf.tanh(final_deconv + lastGlobalNet)
            layers.append(final_output)

        return layers[-1], lastGlobalNet