def feat_encoder(inp):
    global reuse_enc, bn_training
    with tf.variable_scope('encoder', reuse=reuse_enc):
        mod = Model(inp)
        mod.set_bn_training(bn_training)
        mod.convLayer(7, 16, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 128
        mod.maxpoolLayer(3, stride=2)  # 64
        mod.convLayer(5, 64, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 32
        block(mod, 256, 2)  # 16
        block(mod, 256, 1)
        block(mod, 256, 1)
        mod.SelfAttention(32)
        block(mod, 256, 2)  # 8
        block(mod, 256, 1)
        block(mod, 256, 1)
        mod.SelfAttention(32)
        block(mod, 512, 2)  # 4
        block(mod, 512, 1)
        block(mod, 512, 1)
    reuse_enc = True
    return mod.get_current_layer()

def generator(inp):
    global reuse_gen, bn_training
    with tf.variable_scope('generator', reuse=reuse_gen):
        mod = Model(inp)
        mod.set_bn_training(bn_training)
        mod.reshape([-1, 2, 2, 512])
        mod.deconvLayer(3, 256, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 8
        mod.deconvLayer(3, 128, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 16
        mod.SelfAttention(32)
        mod.deconvLayer(5, 64, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 32
        mod.deconvLayer(5, 32, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 64
        mod.deconvLayer(5, 16, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 128
        mod.deconvLayer(5, 3, activation=M.PARAM_TANH)  # output
    reuse_gen = True
    return mod.get_current_layer()

def discriminator(inp, age_size):
    global reuse_dis, bn_training, blknum
    blknum = 0
    with tf.variable_scope('discriminator', reuse=reuse_dis):
        mod = Model(inp)
        mod.set_bn_training(bn_training)
        mod.convLayer(7, 16, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 64
        mod.convLayer(5, 32, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 32
        mod.SelfAttention(4)
        feat = mod.convLayer(5, 64, stride=2, activation=M.PARAM_LRELU)  # 16
        mod.batch_norm()
        mod.convLayer(3, 128, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 8
        adv = mod.convLayer(3, 1)
        mod.set_current_layer(feat)
        block(mod, 128, 1)
        block(mod, 128, 2)  # 8
        block(mod, 256, 1)
        mod.SelfAttention(32)
        block(mod, 256, 1)
        block(mod, 256, 2)  # 4
        block(mod, 256, 1)
        mod.SelfAttention(32)
        block(mod, 256, 2)  # 2
        block(mod, 256, 1)
        mod.flatten()
        mod.fcLayer(512, activation=M.PARAM_LRELU)
        age = mod.fcLayer(age_size)
    reuse_dis = True
    return adv, age

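# --- Hypothetical usage sketch for the two discriminator heads (not from the source) ---
# The discriminator returns a patch-wise adversarial logit map `adv` and age
# logits `age`. The loss split below is only an assumption based on that
# output structure (standard sigmoid patch-GAN loss + softmax age classification);
# `real_imgs`, `fake_imgs`, and `age_labels` are hypothetical tensors.
def example_discriminator_losses(real_imgs, fake_imgs, age_labels, age_size):
    adv_real, age_real = discriminator(real_imgs, age_size)
    adv_fake, _ = discriminator(fake_imgs, age_size)  # reuse_dis is True after the first call
    # Patch-wise adversarial loss on the 8x8x1 `adv` logit map.
    d_adv = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(adv_real), logits=adv_real))
    d_adv += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(adv_fake), logits=adv_fake))
    # Age-group classification loss on the `age` logits of real images.
    d_age = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=age_labels, logits=age_real))
    return d_adv, d_age
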
def age_encoder(inp, ind):
    global reuse_age_enc
    name = 'decoder' + str(ind)
    reuse = name in reuse_age_enc
    with tf.variable_scope(name, reuse=reuse):
        mod = Model(inp)
        mod.fcLayer(2 * 2 * 512, activation=M.PARAM_RELU)
        mod.SelfAttention(is_fc=True, residual=True)
    reuse_age_enc[name] = 1
    return mod.get_current_layer()

def generator_att(inp):
    global reuse_genatt, bn_training, blknum
    blknum = 0
    with tf.variable_scope('gen_att', reuse=reuse_genatt):
        mod = Model(inp)
        mod.set_bn_training(bn_training)
        mod.convLayer(5, 32, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 64
        block(mod, 64, 1)
        mod.convLayer(3, 128, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 32
        block(mod, 128, 1)
        # block(mod,256,1)
        # block(mod,256,1)
        block(mod, 256, 1)
        mod.SelfAttention(64, residual=True)
        # block(mod,512,1)
        block(mod, 256, 1)
        # block(mod,256,1)
        block(mod, 128, 1)
        mod.deconvLayer(3, 64, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 64
        block(mod, 64, 1)
        feat = mod.deconvLayer(5, 64, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 128
        A = mod.convLayer(5, 1, activation=M.PARAM_SIGMOID)  # output attention mask
        mod.set_current_layer(feat)
        C = mod.convLayer(5, 3, activation=M.PARAM_TANH)  # output content
    reuse_genatt = True
    return A, C

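# --- Hypothetical composition of the attention output (not shown in this file) ---
# A common convention for attention-based generators is to blend the generated
# content C into the input image using the single-channel mask A:
#     out = A * C + (1 - A) * inp
# The file does not show how A and C are actually combined, so the helper below
# is only a sketch based on that convention.
def example_apply_attention(inp, A, C):
    # inp, C: [N, 128, 128, 3] in [-1, 1]; A: [N, 128, 128, 1] in (0, 1), broadcast over channels.
    return A * C + (1.0 - A) * inp
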
# NOTE: this second definition of generator_att shadows the one above; after the
# module loads, only this version (which decodes from a 2x2 latent rather than an
# image) is callable.
def generator_att(inp):
    global reuse_genatt, bn_training
    with tf.variable_scope('gen_att', reuse=reuse_genatt):
        mod = Model(inp)
        mod.set_bn_training(bn_training)
        mod.deconvLayer(3, 512, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 4
        mod.deconvLayer(3, 256, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 8
        mod.deconvLayer(3, 128, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 16
        mod.SelfAttention(32)
        mod.deconvLayer(5, 64, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 32
        mod.deconvLayer(5, 32, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 64
        feat = mod.deconvLayer(5, 16, stride=2, activation=M.PARAM_LRELU, batch_norm=True)  # 128
        A = mod.convLayer(5, 3, activation=M.PARAM_SIGMOID)  # output attention mask
        mod.set_current_layer(feat)
        C = mod.convLayer(5, 3, activation=M.PARAM_TANH)  # output content
    reuse_genatt = True
    return A, C

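# --- Minimal graph-construction sketch (hypothetical, not from the source) ---
# Placeholder shapes are inferred from the spatial-size comments above (256x256
# input for feat_encoder, 128x128 for the discriminator, a 2*2*512-dim vector
# for generator). Model, M, block(), and the reuse_* / bn_training globals are
# assumed to be defined earlier in the file. The one-hot age input to
# age_encoder and age_size=10 are arbitrary example choices, and how the
# identity feature and age code are combined into the generator input is not
# shown in this excerpt.
def example_build_graph(age_size=10):
    enc_in = tf.placeholder(tf.float32, [None, 256, 256, 3])
    dis_in = tf.placeholder(tf.float32, [None, 128, 128, 3])
    age_in = tf.placeholder(tf.float32, [None, age_size])
    latent = tf.placeholder(tf.float32, [None, 2 * 2 * 512])

    identity_feat = feat_encoder(enc_in)       # 4x4x512 identity feature map
    age_code = age_encoder(age_in, ind=0)      # 2*2*512-dim age code
    fake = generator(latent)                   # decoded image in [-1, 1]
    adv_logits, age_logits = discriminator(dis_in, age_size)
    return identity_feat, age_code, fake, adv_logits, age_logits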