def get_Ec(x_shape=(None, flags.img_size_h, flags.img_size_w, flags.c_dim), name=None):
    """Build the content encoder E_c.

    Reference: "Multimodal Unsupervised Image-to-Image Translation".
    Layout: a 7x7 stem conv (64 filters, LeakyReLU 0.01), two stride-2
    3x3 convs that each double the channel count (followed by instance
    norm + ReLU), four residual blocks at that width, and a final
    ``GaussianNoise(is_always=False)`` layer.

    NOTE(review): ``g_init`` is read from module scope (not defined here).

    Args:
        x_shape: Shape passed to ``Input`` for the image batch.
        name: Optional name for the returned ``Model``.

    Returns:
        A TensorLayer ``Model`` from image to content code.
    """
    leaky = lambda t: tl.act.lrelu(t, 0.01)
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    width = 64

    inputs = Input(x_shape)
    net = Conv2d(width, (7, 7), (1, 1), act=leaky, W_init=kernel_init)(inputs)

    # Two stride-2 stages; the channel count doubles at each one.
    for _ in range(2):
        width *= 2
        net = Conv2d(width, (3, 3), (2, 2), W_init=kernel_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

    # Four residual blocks: conv-norm-relu-conv-norm, added back to the input.
    for _ in range(4):
        branch = Conv2d(width, (3, 3), (1, 1), act=None, W_init=kernel_init, b_init=None)(net)
        branch = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(branch)
        branch = Conv2d(width, (3, 3), (1, 1), act=None, W_init=kernel_init, b_init=None)(branch)
        branch = InstanceNorm2d(act=None, gamma_init=g_init)(branch)
        net = Elementwise(tf.add)([net, branch])

    net = GaussianNoise(is_always=False)(net)
    return Model(inputs=inputs, outputs=net, name=name)
def build_resnet_block_Att(inputres, dim, name="resnet", padding="REFLECT"):
    """Residual block of two padded 3x3 convolutions with instance norm.

    Each conv is preceded by a 1-pixel spatial ``PadLayer`` (mode
    ``padding``) and run with "VALID" padding, so the spatial size is kept.
    ReLU follows the first norm only; the second norm output is added to
    ``inputres`` and the sum is passed through ReLU.

    Args:
        inputres: Input tensor. The residual add presumably requires it to
            have ``dim`` channels -- TODO confirm at call sites.
        dim: Filter count for both convolutions.
        name: Name for the enclosing ``tf.compat.v1.variable_scope``.
        padding: Pad mode forwarded to ``PadLayer`` (e.g. "REFLECT").

    Returns:
        The block's output tensor.
    """
    with tf.compat.v1.variable_scope(name):
        branch = inputres
        for stage in range(2):
            branch = PadLayer([[0, 0], [1, 1], [1, 1], [0, 0]], padding)(branch)
            branch = Conv2d(n_filter=dim, filter_size=(3, 3), strides=(1, 1), padding="VALID", act=None,
                            W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                            b_init=tf.constant_initializer(0.0))(branch)
            # Only the first stage is activated; the second feeds the residual sum.
            branch = InstanceNorm2d(act=tf.nn.relu if stage == 0 else None)(branch)
        summed = Elementwise(combine_fn=tf.add)([branch, inputres])
        return Lambda(tf.nn.relu)(summed)
def get_G(a_shape=(None, flags.za_dim), c_shape=(None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]),
          name=None):
    """Generator that decodes a content code conditioned on an attribute vector.

    The attribute input is expanded twice along axis 1 and tiled to the
    content map's spatial size (``c_shape[1]`` x ``c_shape[2]``). It is
    concatenated onto the features before the three-residual-block trunk
    and before each of the two stride-2 transposed convs (being tiled 2x
    after each use), and once more before the final 1x1 tanh layer that
    emits 3 channels.

    The first residual block adds the raw content input to a 256-filter
    branch, so the content code presumably must have 256 channels --
    TODO confirm. ``w_init`` / ``g_init`` are read from module scope.

    NOTE: another ``def get_G`` later in this module rebinds the name, so
    after import ``get_G`` refers to that later definition, not this one.

    Args:
        a_shape: Shape for the attribute ``Input``.
        c_shape: Shape for the content ``Input``.
        name: Optional ``Model`` name.

    Returns:
        A ``Model`` with inputs ``[attribute, content]``.
    """
    width = 256
    za_dim = flags.za_dim

    attr_in = Input(a_shape)
    content_in = Input(c_shape)

    # Broadcast the attribute vector over the content grid.
    attr_map = ExpandDims(1)(attr_in)
    attr_map = ExpandDims(1)(attr_map)
    attr_map = Tile([1, c_shape[1], c_shape[2], 1])(attr_map)

    def residual(x, filters):
        # conv-norm-relu-conv-norm branch added back onto x
        y = Conv2d(filters, (3, 3), (1, 1), act=None, W_init=w_init, b_init=None)(x)
        y = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(y)
        y = Conv2d(filters, (3, 3), (1, 1), act=None, W_init=w_init, b_init=None)(y)
        y = InstanceNorm2d(act=None, gamma_init=g_init)(y)
        return Elementwise(tf.add)([x, y])

    net = residual(content_in, width)

    width += za_dim
    net = Concat(-1)([net, attr_map])
    for _ in range(3):
        net = residual(net, width)

    # Two up-sampling stages; the attribute map is re-injected before each
    # and then tiled 2x so it matches the next resolution.
    for _ in range(2):
        width += za_dim
        net = Concat(-1)([net, attr_map])
        attr_map = Tile([1, 2, 2, 1])(attr_map)
        net = DeConv2d(width // 2, (3, 3), (2, 2), act=tf.nn.relu, W_init=w_init, b_init=None)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)
        width //= 2

    net = Concat(-1)([net, attr_map])
    net = DeConv2d(3, (1, 1), (1, 1), act=tf.nn.tanh, W_init=w_init)(net)
    return Model(inputs=[attr_in, content_in], outputs=net, name=name)
def get_D_content(c_shape=(None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]), name=None):
    """Discriminator over content codes.

    Reference: DRIT source code (PyTorch implementation).
    Three 7x7 stride-2 convs (256 filters) and one 4x4 "VALID" conv, each
    followed by instance norm with ``lrelu``, then a 1-filter 5x5 conv
    with stride 5 and "VALID" padding and no activation. The result is
    reshaped to ``[-1, 1]``.

    ``w_init``, ``g_init`` and ``lrelu`` are read from module scope.

    Args:
        c_shape: Shape for the content ``Input``.
        name: Optional ``Model`` name. It was previously hard-coded to
            ``None``; the default keeps that behavior and matches the other
            builders in this module.

    Returns:
        A TensorLayer ``Model`` producing un-activated scores of shape ``[-1, 1]``.
    """
    ni = Input(c_shape)
    n = ni
    for _ in range(3):
        n = Conv2d(256, (7, 7), (2, 2), act=None, W_init=w_init)(n)
        n = InstanceNorm2d(act=lrelu, gamma_init=g_init)(n)
    n = Conv2d(256, (4, 4), (1, 1), act=None, padding='VALID', W_init=w_init)(n)
    n = InstanceNorm2d(act=lrelu, gamma_init=g_init)(n)
    n = Conv2d(1, (5, 5), (5, 5), padding='VALID', W_init=w_init)(n)
    n = Reshape(shape=[-1, 1])(n)
    return tl.models.Model(inputs=ni, outputs=n, name=name)
def get_G(name=None):
    """ResNet image-to-image generator with nine residual blocks.

    Input is fixed to ``(flags.batch_size, 256, 256, 3)``. Layout: 7x7 stem
    conv (32 filters), two 3x3 stride-2 convs (64, 128 filters), nine
    residual blocks at 128 filters, two 3x3 stride-2 transposed convs
    (64, 32 filters), and a 7x7 conv to 3 channels with tanh. Every
    intermediate conv is followed by instance norm (ReLU except on the
    second norm inside each residual branch).

    NOTE: this definition rebinds the name ``get_G`` defined earlier in the module.

    Args:
        name: Optional ``Model`` name.

    Returns:
        A TensorLayer ``Model``.
    """
    base = 32
    kernel_init = tf.random_normal_initializer(stddev=0.02)

    image = Input((flags.batch_size, 256, 256, 3))
    net = Conv2d(base, (7, 7), (1, 1), W_init=kernel_init)(image)
    net = InstanceNorm2d(act=tf.nn.relu)(net)

    for mult in (2, 4):
        net = Conv2d(base * mult, (3, 3), (2, 2), W_init=kernel_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)

    for _ in range(9):
        branch = Conv2d(base * 4, (3, 3), (1, 1), W_init=kernel_init)(net)
        branch = InstanceNorm2d(act=tf.nn.relu)(branch)
        branch = Conv2d(base * 4, (3, 3), (1, 1), W_init=kernel_init)(branch)
        branch = InstanceNorm2d()(branch)
        net = Elementwise(tf.add)([net, branch])

    for mult in (2, 1):
        net = DeConv2d(base * mult, (3, 3), (2, 2), W_init=kernel_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)

    net = Conv2d(3, (7, 7), (1, 1), act=tf.nn.tanh, W_init=kernel_init)(net)
    return Model(inputs=image, outputs=net, name=name)
def get_G_zc(shape_z=(None, flags.zc_dim), gf_dim=64, name='Generator_zc'):
    """DCGAN-style generator mapping a noise vector to a content code.

    A bias-free Dense layer projects the input to ``gf_dim * 8 * s16 * s16``
    units (``s16 = 64 // 16 = 4``), reshaped to ``[-1, s16, s16, gf_dim * 8]``.
    Three 5x5 stride-2 transposed convs (gf_dim*4, gf_dim*2, gf_dim
    filters, each with instance norm + ReLU) and a final 5x5 stride-2
    transposed conv to 256 channels with tanh follow. With the layers'
    default "SAME" padding the output is presumably 64x64 spatially --
    TODO confirm.

    ``w_init`` / ``g_init`` are read from module scope.

    Args:
        shape_z: Shape for the noise ``Input``.
        gf_dim: Base filter count.
        name: ``Model`` name. It was previously hard-coded; the default keeps
            the old value, and callers can now give each instance a
            distinct name like the other builders in this module.

    Returns:
        A TensorLayer ``Model``.
    """
    output_size = 64
    s16 = output_size // 16

    ni = Input(shape_z)
    nn = Dense(n_units=(gf_dim * 8 * s16 * s16), W_init=w_init, b_init=None)(ni)
    nn = Reshape(shape=[-1, s16, s16, gf_dim * 8])(nn)
    nn = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(nn)
    for filters in (gf_dim * 4, gf_dim * 2, gf_dim):
        nn = DeConv2d(filters, (5, 5), (2, 2), W_init=w_init, b_init=None)(nn)
        nn = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(nn)
    nn = DeConv2d(256, (5, 5), (2, 2), act=tf.nn.tanh, W_init=w_init)(nn)
    return tl.models.Model(inputs=ni, outputs=nn, name=name)
def get_patch_D(input_shape, name="discriminator"):
    """Patch discriminator that also exposes intermediate features.

    Four 4x4 stride-2 "SAME" convs (64, 128, 256, 512 filters) and a final
    1x1 conv to one channel. The model's outputs are a list of five
    tensors: the output of each conv, captured *before* the instance norm
    (with LeakyReLU 0.2) that follows convs 2-4. The first conv applies
    LeakyReLU inline and has no norm.

    Args:
        input_shape: Shape passed to ``Input``.
        name: ``Model`` name.

    Returns:
        A TensorLayer ``Model`` whose outputs are the five feature tensors.
    """
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    scale_init = tf.random_normal_initializer(1., 0.02)
    base = 64
    leaky = lambda t: tl.act.lrelu(t, 0.2)

    features = []
    inputs = Input(input_shape)
    net = Conv2d(base, (4, 4), (2, 2), act=leaky, padding='SAME', W_init=kernel_init)(inputs)
    features.append(net)

    for mult in (2, 4, 8):
        net = Conv2d(base * mult, (4, 4), (2, 2), padding='SAME', W_init=kernel_init, b_init=None)(net)
        features.append(net)  # pre-normalization activation
        net = InstanceNorm2d(act=leaky, gamma_init=scale_init)(net)

    net = Conv2d(1, (1, 1), (1, 1), padding='SAME', W_init=kernel_init, b_init=None)(net)
    features.append(net)
    return Model(inputs=inputs, outputs=features, name=name)
def get_D(name=None):
    """70x70 PatchGAN discriminator applied to a random crop.

    Takes a ``(flags.batch_size, 256, 256, 3)`` batch and crops it with
    ``tf.image.random_crop`` to ``[flags.batch_size, 70, 70, 3]``. Because
    the crop size spans the whole batch dimension, one spatial offset is
    shared by all images in the batch. The crop is reduced by four 4x4
    stride-2 convs (64..512 filters, instance norm + LeakyReLU 0.2 after
    convs 2-4) and a 4x4 stride-4 "VALID" conv, then flattened to one score
    per image.

    Args:
        name: Optional ``Model`` name.

    Returns:
        A TensorLayer ``Model`` with output shape ``(batch, 1)``.
    """
    base = 64
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    leaky = lambda t: tl.act.lrelu(t, 0.2)

    images = Input((flags.batch_size, 256, 256, 3))
    net = Lambda(lambda t: tf.image.random_crop(t, [flags.batch_size, 70, 70, 3]))(images)

    net = Conv2d(base, (4, 4), (2, 2), act=leaky, W_init=kernel_init)(net)
    for mult in (2, 4, 8):
        net = Conv2d(base * mult, (4, 4), (2, 2), W_init=kernel_init)(net)
        net = InstanceNorm2d(act=leaky)(net)

    # 70 -> 35 -> 18 -> 9 -> 5 spatially; 4x4/stride-4 VALID then yields 1x1.
    net = Conv2d(1, (4, 4), (4, 4), padding='VALID', W_init=kernel_init)(net)
    net = Flatten()(net)
    assert net.shape[-1] == 1
    return Model(inputs=images, outputs=net, name=name)
def cycle_G(input_shape, name="dx_generator"):
    """CycleGAN-style ResNet generator with a single-channel tanh output.

    Layout: 7x7 conv (64), two 3x3 stride-2 convs (128, 256), nine
    residual blocks at 256 filters, two ``DeConv2d`` layers (128, 64
    filters) built with TensorLayer's default filter size / strides /
    initializers, and a 7x7 conv to 1 channel with tanh. Every conv except
    the last is followed by instance norm + ReLU.

    Args:
        input_shape: Shape passed to ``Input``.
        name: ``Model`` name.

    Returns:
        A TensorLayer ``Model``.
    """
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    scale_init = tf.random_normal_initializer(1., 0.02)

    inputs = Input(input_shape)
    net = Conv2d(64, (7, 7), (1, 1), padding='SAME', W_init=kernel_init)(inputs)
    net = InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(net)
    for filters in (128, 256):
        net = Conv2d(filters, (3, 3), (2, 2), padding='SAME', W_init=kernel_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(net)

    for _ in range(9):
        branch = net
        for _ in range(2):
            branch = Conv2d(256, (3, 3), (1, 1), padding='SAME', W_init=kernel_init, b_init=None)(branch)
            # NOTE(review): both norms in this branch apply ReLU, whereas the
            # residual blocks in get_Ec / get_G use act=None on the second
            # norm -- confirm this is intended.
            branch = InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(branch)
        net = Elementwise(tf.add)([net, branch])

    for filters in (128, 64):
        net = DeConv2d(n_filter=filters)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(net)

    out = Conv2d(1, (7, 7), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=kernel_init)(net)
    return Model(inputs=inputs, outputs=out, name=name)
def discriminator(name="discriminator"):
    """Masked patch discriminator with switchable instance normalization.

    Model inputs, in order:
        image: ``[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS]`` float32.
        mask: same shape as the image.
        transition_rate: shape ``[1]``; compared with ``mask`` via
            ``tf.greater_equal``. The boolean result goes through ``my_cast``
            (presumably to float -- TODO confirm) and multiplies the image.
        donorm: shape ``[1]``; passed with the normalized and raw
            activations to ``my_cond`` at every stage (presumably selecting
            one of them -- TODO confirm).

    Four zero-padded (pad 2) 4x4 "VALID" convs with ndf, 2*ndf, 4*ndf
    (stride 2) and 8*ndf (stride 1) filters, each followed by the
    ``my_cond`` selection and LeakyReLU 0.2, then a 1-filter 4x4 conv.
    ``ndf``, ``my_cast`` and ``my_cond`` come from module scope.

    Returns:
        A TensorLayer ``Model`` producing the un-activated score map.
    """
    leaky = lambda x: tl.act.lrelu(x, 0.2)
    k = 4
    padw = 2

    def zero_pad(x):
        return PadLayer([[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")(x)

    def conv(x, filters, stride):
        return Conv2d(n_filter=filters, filter_size=(k, k), strides=(stride, stride), padding="VALID",
                      act=None, W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                      b_init=tf.constant_initializer(0.0))(x)

    with tf.compat.v1.variable_scope(name):
        image_in = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS], dtype=tf.float32)
        mask_in = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS], dtype=tf.float32)
        transition_rate = Input(shape=[1], dtype=tf.float32)
        donorm = Input(shape=[1], dtype=tf.float32)

        keep = Elementwise(combine_fn=tf.greater_equal)([mask_in, transition_rate])
        mask = Lambda(fn=my_cast)(keep)
        net = Elementwise(combine_fn=tf.multiply)([image_in, mask])

        for mult, stride in ((1, 2), (2, 2), (4, 2), (8, 1)):
            net = conv(zero_pad(net), ndf * mult, stride)
            # Build the selector before the norm layer, as the original did.
            selector = Lambda(fn=my_cond)
            net = selector([donorm, InstanceNorm2d(act=None)(net), net])
            net = Lambda(fn=leaky)(net)

        scores = conv(zero_pad(net), 1, 1)
        return Model(inputs=[image_in, mask_in, transition_rate, donorm], outputs=scores)
def build_generator_9blocks(name="generator", skip=False):
    """ResNet generator with nine ``build_resnet_block`` residual stages.

    Input: ``[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS]`` float32. Layout:
    zero pad of 3 + 7x7 "VALID" conv (ngf), two 3x3 stride-2 "SAME" convs
    (2*ngf, 4*ngf), nine residual blocks named r1..r9 at 4*ngf, two 3x3
    stride-2 transposed convs (2*ngf, ngf), and a 7x7 "SAME" conv to
    IMG_CHANNELS. Intermediate convs are followed by instance norm + ReLU.

    Args:
        name: Name for the enclosing ``tf.compat.v1.variable_scope``.
        skip: When exactly ``True`` (identity check), the input image is added
            to the final conv output before ``tanh``.

    Returns:
        A TensorLayer ``Model`` whose output is passed through ``tanh``.
    """
    def conv_layer(layer_cls, filters, size, stride, pad_mode):
        return layer_cls(n_filter=filters, filter_size=(size, size), strides=(stride, stride),
                         padding=pad_mode, act=None,
                         W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                         b_init=tf.constant_initializer(0.0))

    with tf.compat.v1.variable_scope(name):
        inputgen = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS], dtype=tf.float32)
        f = 7
        ks = 3
        padding = "CONSTANT"

        net = PadLayer([[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)(inputgen)
        net = conv_layer(Conv2d, ngf, f, 1, "VALID")(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)
        for mult in (2, 4):
            net = conv_layer(Conv2d, ngf * mult, ks, 2, "SAME")(net)
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        for idx in range(1, 10):
            net = build_resnet_block(net, ngf * 4, "r%d" % idx, padding)

        for mult in (2, 1):
            net = conv_layer(DeConv2d, ngf * mult, ks, 2, "SAME")(net)
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        pre_act = conv_layer(Conv2d, IMG_CHANNELS, f, 1, "SAME")(net)
        if skip is True:
            pre_act = Elementwise(combine_fn=tf.add)([inputgen, pre_act])
        out_gen = Lambda(tf.nn.tanh)(pre_act)
        return Model(inputs=inputgen, outputs=out_gen)
def autoenc_upsample(name):
    """Encoder / residual / upsampling network ending in a 1-channel sigmoid map.

    Input: ``[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS]`` float32. Layout:
    reflect pad of 3 + 7x7 stride-2 "VALID" conv (ngf), 3x3 stride-2
    "SAME" conv (2*ngf), one ``build_resnet_block_Att`` (2*ngf, "r1"), then
    two stages of ``upsamplingDeconv`` to twice the current spatial size,
    reflect pad of 1 and 3x3 "VALID" conv (2*ngf, then ngf). It ends with a
    reflect pad of 3, a 7x7 "VALID" conv to 1 channel and a sigmoid.
    Intermediate convs are followed by instance norm + ReLU.

    Args:
        name: Name for the enclosing ``tf.compat.v1.variable_scope``.

    Returns:
        A TensorLayer ``Model``.
    """
    def conv(x, filters, size, stride, pad_mode):
        return Conv2d(n_filter=filters, filter_size=(size, size), strides=(stride, stride),
                      padding=pad_mode, act=None,
                      W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                      b_init=tf.constant_initializer(0.0))(x)

    with tf.compat.v1.variable_scope(name):
        inputae = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS], dtype=tf.float32)
        f = 7
        ks = 3
        padding = "REFLECT"

        net = PadLayer([[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)(inputae)
        net = conv(net, ngf, f, 2, "VALID")
        net = InstanceNorm2d(act=tf.nn.relu)(net)
        net = conv(net, ngf * 2, ks, 2, "SAME")
        net = InstanceNorm2d(act=tf.nn.relu)(net)

        net = build_resnet_block_Att(net, ngf * 2, "r1", padding)

        for filters, up_name in ((ngf * 2, "up1"), (ngf, "up2")):
            dims = net.get_shape().as_list()
            net = upsamplingDeconv(net, size=[dims[1] * 2, dims[2] * 2], name=up_name)
            net = PadLayer([[0, 0], [1, 1], [1, 1], [0, 0]], padding)(net)
            net = conv(net, filters, 3, 1, "VALID")
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        net = PadLayer([[0, 0], [3, 3], [3, 3], [0, 0]], padding)(net)
        net = conv(net, 1, f, 1, "VALID")
        output = Lambda(tf.nn.sigmoid)(net)
        return Model(inputs=inputae, outputs=output)
def u_net(inputs, refine=False):
    """U-Net built directly on ``inputs``; returns the output tensor, not a Model.

    Encoder: eight 4x4 stride-2 "SAME" convs (64, 128, 256, then 512 x5),
    each followed by instance norm + LeakyReLU 0.2. Decoder: seven
    ``DeConv2d(filter_size=(4, 4))`` stages (512 x4, 256, 128, 64) using
    TensorLayer's other defaults. The first three stages are followed by
    ``Dropout(0.5)``. Each stage's instance-norm + ReLU output is
    concatenated with the mirrored encoder output (7th down to 1st). A
    last 64-filter deconv + norm feeds a 1x1 conv to one channel with tanh.

    Args:
        inputs: Input tensor the layers are applied to.
        refine: If truthy, ``inputs`` is combined with the tanh output via
            ``RampElementwise(tf.add, act=tl.act.ramp, v_min=-1)``.

    Returns:
        The single-channel output tensor.
    """
    kernel_init = tf.random_normal_initializer(stddev=0.02)
    scale_init = tf.random_normal_initializer(1., 0.02)
    leaky = lambda t: tl.act.lrelu(t, 0.2)

    # Encoder; every stage's output is kept for the skip connections.
    encoded = []
    net = inputs
    for filters in (64, 128, 256, 512, 512, 512, 512, 512):
        net = Conv2d(filters, (4, 4), (2, 2), padding='SAME', W_init=kernel_init)(net)
        net = InstanceNorm2d(act=leaky, gamma_init=scale_init)(net)
        encoded.append(net)

    # Decoder; skips run from the 7th encoder output back to the 1st
    # (the 8th is the bottleneck already held in `net`).
    decoder_plan = ((512, True), (512, True), (512, True), (512, False),
                    (256, False), (128, False), (64, False))
    for (filters, use_dropout), skip in zip(decoder_plan, reversed(encoded[:-1])):
        net = DeConv2d(n_filter=filters, filter_size=(4, 4))(net)
        if use_dropout:
            net = Dropout(0.5)(net)
        net = Concat()([InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(net), skip])

    net = DeConv2d(n_filter=64, filter_size=(4, 4))(net)
    net = InstanceNorm2d(act=tf.nn.relu, gamma_init=scale_init)(net)
    net = Conv2d(1, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=kernel_init)(net)
    if refine:
        net = RampElementwise(tf.add, act=tl.act.ramp, v_min=-1)([net, inputs])
    return net