def SPADE(block_number, batch_input, orig_input, k):
    """Spatially-adaptive (de)normalization block.

    Normalizes `batch_input`, then re-modulates it with a per-pixel scale
    and shift predicted from `orig_input` (the conditioning map), which is
    first resized to `batch_input`'s spatial dimensions.

    Args:
        block_number: suffix for the variable scope ("SPADE_<n>").
        batch_input: feature map to be normalized and modulated.
        orig_input: conditioning input (e.g. segmentation map).
        k: number of output channels for the predicted scale/shift.

    Returns:
        The modulated feature map: norm(batch_input) * scale + shift.
    """
    with tf.variable_scope("SPADE_" + str(block_number)):
        # Bring the conditioning map to the feature map's spatial size.
        resized = tf.image.resize_nearest_neighbor(orig_input, batch_input.shape[1:3])
        # Shared hidden representation; lrelu with slope 0 is a plain ReLU.
        hidden = lrelu(conv_layer(2, resized, 128, 1, 3, 0), 0)
        scale = conv_layer(3, hidden, k, 1, 3, 0)
        shift = conv_layer(4, hidden, k, 1, 3, 0)
        # paper says (sync) batch norm, github says instance norm
        normalized = inst_norm("I1", batch_input, 1E-5, 0.1, True, False)
        return normalized * scale + shift
def SPADE_ResBLK(block_number, batch_input, orig_input, k):
    """SPADE residual block: two SPADE->lrelu->conv stages plus a shortcut.

    The shortcut is the identity when the input already has `k` channels;
    otherwise a learned (SPADE->lrelu->conv) projection is used.

    Fix: the original built the learned-shortcut branch (and its conv /
    spectral-norm variables) unconditionally, even when the identity
    shortcut was taken — leaving dead, untrained parameters in the graph.
    The branch is now only constructed when it is actually used; the
    block's output is unchanged in both cases.

    Args:
        block_number: suffix for the variable scope ("SPADE_ResBLK_<n>").
        batch_input: input feature map.
        orig_input: conditioning input forwarded to the SPADE layers.
        k: number of output channels.

    Returns:
        Feature map with `k` channels (main path + shortcut).
    """
    in_channels = batch_input.shape[-1]
    # Middle channel count follows the reference implementation: min(in, out).
    mid = min(in_channels, k)
    with tf.variable_scope("SPADE_ResBLK_" + str(block_number)):
        layer1 = spec_norm('S1', conv_layer(
            1, lrelu(SPADE(1, batch_input, orig_input, in_channels), 0.2), mid, 1, 3, 0))
        layer2 = spec_norm('S2', conv_layer(
            2, lrelu(SPADE(2, layer1, orig_input, mid), 0.2), k, 1, 3, 0))
        if in_channels != k:
            # Learned shortcut, needed only when the channel count changes.
            shortcut = spec_norm('S3', conv_layer(
                3, lrelu(SPADE(3, batch_input, orig_input, in_channels), 0.2), k, 1, 3, 0))
            return layer2 + shortcut
        return layer2 + batch_input
def generator(gen_input, z):
    """SPADE generator: project z to a 4x4 seed, then upsample to an image.

    Fix/generalization: the original reshaped the projected latent with a
    hard-coded batch size of 1 (`[1, 4, 4, 1024]`); `-1` now infers the
    batch dimension, so batches of any size work while batch-1 behavior is
    unchanged. The seven identical (SPADE_ResBLK -> 2x nearest-neighbor
    upsample) stages are expressed as a loop; block numbering (3..9) and
    channel widths match the original, so variable-scope names are
    identical.

    Args:
        gen_input: conditioning input (e.g. segmentation map) fed to every
            SPADE residual block.
        z: latent vector, projected to 16384 units (4*4*1024).

    Returns:
        tanh-activated 3-channel image tensor.
    """
    with tf.variable_scope("generator"):
        seed = linear(1, z, 16384)
        # -1 keeps the batch dimension flexible (original hard-coded 1).
        features = tf.reshape(seed, [-1, 4, 4, 1024])
        # Seven SPADE residual blocks, each followed by 2x upsampling.
        for block_number, channels in zip(range(3, 10),
                                          [1024, 1024, 1024, 512, 256, 128, 64]):
            # Target size is 2x the *input* spatial size, as in the original
            # (SPADE_ResBLK preserves spatial dimensions).
            upsampled = [i * 2 for i in features.shape[1:3]]
            features = tf.image.resize_nearest_neighbor(
                SPADE_ResBLK(block_number, features, gen_input, channels), upsampled)
        return tf.math.tanh(conv_layer(10, lrelu(features, 0.2), 3, 1, 3, 0))
def discriminator(gen_input, gen_out_or_targets, factor, reuse):
    """PatchGAN-style discriminator on a (condition, image) pair.

    Concatenates the conditioning input with either generator output or
    real targets along channels, then applies a conv ladder. All
    intermediate activations are returned (useful for feature-matching
    losses), with the final 1-channel logit map last.

    Args:
        factor: suffix for the variable scope ("discriminator_<factor>"),
            e.g. one scale of a multi-scale discriminator.
        reuse: passed to tf.variable_scope for variable sharing.

    Returns:
        List of five tensors: four intermediate activations and the logits.
    """
    with tf.variable_scope("discriminator_" + str(factor), reuse=reuse):
        taps = [tf.concat([gen_input, gen_out_or_targets], axis=3)]
        # First stage has no normalization.
        taps.append(lrelu(conv_layer(1, taps[-1], 64, 2, 4, 0.2), 0.2))
        # Three normalized stages; the last keeps stride 1.
        for idx, (channels, stride) in enumerate([(128, 2), (256, 2), (512, 1)], start=2):
            conv = conv_layer(idx, taps[-1], channels, stride, 4, 0.2)
            normed = inst_norm('I' + str(idx),
                               spec_norm('S' + str(idx - 1), conv),
                               1E-5, 0.1, True, False)
            taps.append(lrelu(normed, 0.2))
        # Final 1-channel logit map, no activation.
        taps.append(conv_layer(5, taps[-1], 1, 1, 4, 0))
        return taps[1:]
def encoder(encoder_input):
    """VAE-style image encoder producing a reparameterized latent sample.

    Six stride-2 conv stages (spectral norm -> instance norm -> leaky ReLU)
    downsample the input, which is then flattened and projected to the
    latent mean and log-variance.

    Args:
        encoder_input: input image tensor.

    Returns:
        (mu, logvar, z) where z = mu + sigma * eps with eps ~ N(0, 1)
        (the reparameterization trick).
    """
    with tf.variable_scope("encoder"):
        features = encoder_input
        for idx, channels in enumerate([64, 128, 256, 512, 512, 512], start=1):
            conv = conv_layer(idx, features, channels, 2, 3, 0.2)
            normed = inst_norm('I' + str(idx),
                               spec_norm('S' + str(idx), conv),
                               1E-5, 0.1, True, False)
            features = lrelu(normed, 0.2)
        # NOTE(review): the hard-coded 8192 presumably assumes batch size 1
        # and a fixed input resolution (512 channels over a 4x4 map) —
        # confirm against the caller before changing input sizes.
        flat = tf.reshape(features, [8192, 1, 1])
        mu = linear(1, flat, 256)  # mu
        logvar = linear(2, flat, 256)  # logvar
        sigma = tf.math.exp(0.5 * logvar)
        eps = tf.random.normal(sigma.shape, mean=0.0, stddev=1.0)
        return mu, logvar, sigma * eps + mu