Example #1
0
def get_Ec(x_shape=(None, flags.img_size_h, flags.img_size_w, flags.c_dim),
           name=None):
    """Build the content encoder E_c.

    Layout: a 7x7 stride-1 conv stem (64 channels, leaky ReLU with slope
    0.01), two 3x3 stride-2 convs that each double the channel count and are
    followed by InstanceNorm + ReLU, four residual blocks at the resulting
    width (256 channels), and finally a ``GaussianNoise(is_always=False)``
    layer.

    Args:
        x_shape: Shape given to ``Input``. The default is built from
            ``flags`` when the function is defined.
        name: Optional model name.

    Returns:
        A ``Model`` from ``x_shape`` images to the encoded feature map.

    NOTE(review): ``g_init`` is not defined here (unlike ``w_init``), so it
    is resolved from module scope. Confirm it exists where this is used.
    """
    # ref: Multimodal Unsupervised Image-to-Image Translation
    leaky = lambda t: tl.act.lrelu(t, 0.01)
    w_init = tf.random_normal_initializer(stddev=0.02)
    width = 64

    inputs = Input(x_shape)
    net = Conv2d(width, (7, 7), (1, 1), act=leaky, W_init=w_init)(inputs)

    # Two downsampling stages; each doubles the channel count.
    for _ in range(2):
        width *= 2
        net = Conv2d(width, (3, 3), (2, 2), W_init=w_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

    # Four residual blocks: conv-IN-ReLU-conv-IN, then add the block input.
    for _ in range(4):
        branch = Conv2d(width, (3, 3), (1, 1), act=None, W_init=w_init,
                        b_init=None)(net)
        branch = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(branch)
        branch = Conv2d(width, (3, 3), (1, 1), act=None, W_init=w_init,
                        b_init=None)(branch)
        branch = InstanceNorm2d(act=None, gamma_init=g_init)(branch)
        net = Elementwise(tf.add)([net, branch])

    net = GaussianNoise(is_always=False)(net)

    return Model(inputs=inputs, outputs=net, name=name)
Example #2
0
def build_resnet_block_Att(inputres, dim, name="resnet", padding="REFLECT"):
    """Residual block with two padded 3x3 convolutions.

    Each stage pads H and W by one pixel using mode ``padding``. It then
    applies a 3x3 stride-1 VALID conv with ``dim`` filters, followed by
    InstanceNorm. The first stage's norm uses ReLU; the second has no
    activation. The block returns ``ReLU(F(x) + x)``.

    Args:
        inputres: Input tensor/layer output. Its channel count must equal
            ``dim`` for the skip addition to be valid.
        dim: Number of conv filters.
        name: Variable-scope name.
        padding: Pad mode passed to ``PadLayer`` (e.g. "REFLECT").

    Returns:
        The output of the final ReLU ``Lambda`` layer.
    """
    with tf.compat.v1.variable_scope(name):
        hidden = inputres
        # Stage 1 ends in ReLU; stage 2 stays linear before the skip sum.
        for norm_act in (tf.nn.relu, None):
            hidden = PadLayer([[0, 0], [1, 1], [1, 1], [0, 0]],
                              padding)(hidden)
            hidden = Conv2d(
                n_filter=dim,
                filter_size=(3, 3),
                strides=(1, 1),
                padding="VALID",
                act=None,
                W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                b_init=tf.constant_initializer(0.0),
            )(hidden)
            hidden = InstanceNorm2d(act=norm_act)(hidden)

        summed = Elementwise(combine_fn=tf.add)([hidden, inputres])
        return Lambda(tf.nn.relu)(summed)
Example #3
0
def get_G(a_shape=(None, flags.za_dim),
          c_shape=(None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]),
          name=None):
    """Generator that decodes a content map conditioned on an attribute code.

    The attribute vector is expanded to rank 4 and tiled over the content
    map's spatial grid. It is then concatenated onto the feature map
    repeatedly:
    after the first residual block, before each of the two stride-2 DeConv
    upsampling stages (tiling it 2x spatially each time), and before the
    final 1x1 tanh DeConv that emits 3 channels.

    Args:
        a_shape: Shape of the attribute input, ``(None, za_dim)``.
        c_shape: Shape of the content input, ``(None, H, W, C)``.
        name: Optional model name.

    Returns:
        A ``Model`` with inputs ``[attribute, content]``.

    NOTE(review): ``w_init`` and ``g_init`` are not defined here and are
    resolved from module scope. The first residual block adds the content
    input to a 256-channel branch, so ``c_shape[3]`` presumably must be 256.
    Confirm this.
    """
    za_dim = flags.za_dim
    width = 256

    def residual(x, channels):
        # conv-IN-ReLU-conv-IN branch added back onto its input.
        branch = Conv2d(channels, (3, 3), (1, 1), act=None, W_init=w_init,
                        b_init=None)(x)
        branch = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(branch)
        branch = Conv2d(channels, (3, 3), (1, 1), act=None, W_init=w_init,
                        b_init=None)(branch)
        branch = InstanceNorm2d(act=None, gamma_init=g_init)(branch)
        return Elementwise(tf.add)([x, branch])

    attr_in = Input(a_shape)
    content_in = Input(c_shape)

    # (B, za) -> (B, 1, za) -> (B, 1, 1, za) -> (B, H, W, za)
    attr_map = ExpandDims(1)(attr_in)
    attr_map = ExpandDims(1)(attr_map)
    attr_map = Tile([1, c_shape[1], c_shape[2], 1])(attr_map)

    net = residual(content_in, width)

    width += za_dim
    net = Concat(-1)([net, attr_map])

    for _ in range(3):
        net = residual(net, width)

    # Two upsampling stages; the attribute map is tiled to keep pace.
    for _ in range(2):
        width += za_dim
        net = Concat(-1)([net, attr_map])
        attr_map = Tile([1, 2, 2, 1])(attr_map)

        net = DeConv2d(width // 2, (3, 3), (2, 2),
                       act=tf.nn.relu,
                       W_init=w_init,
                       b_init=None)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

        width //= 2

    net = Concat(-1)([net, attr_map])
    net = DeConv2d(3, (1, 1), (1, 1), act=tf.nn.tanh, W_init=w_init)(net)

    return Model(inputs=[attr_in, content_in], outputs=net, name=name)
Example #4
0
def get_D_content(c_shape=(None, flags.c_shape[0], flags.c_shape[1],
                           flags.c_shape[2])):
    """Content discriminator operating on encoded content maps.

    Layout: three 7x7 stride-2 convs (256 filters, default padding), a 4x4
    stride-1 VALID conv (256 filters), each followed by InstanceNorm with
    ``lrelu``. Then a 5x5 stride-5 VALID conv to one channel, reshaped to
    ``[-1, 1]``.

    Args:
        c_shape: Shape of the content input, ``(None, H, W, C)``.

    Returns:
        An unnamed ``tl.models.Model`` producing the reshaped scores.

    NOTE(review): ``w_init``, ``g_init`` and ``lrelu`` come from module
    scope. Whether the output is one score per sample depends on the final
    spatial size being 1x1. Confirm this for the configured ``c_shape``.
    """
    # reference: DRIT resource code -- Pytorch implementation
    inputs = Input(c_shape)
    net = inputs

    for _ in range(3):
        net = Conv2d(256, (7, 7), (2, 2), act=None, W_init=w_init)(net)
        net = InstanceNorm2d(act=lrelu, gamma_init=g_init)(net)

    net = Conv2d(256, (4, 4), (1, 1), act=None, padding='VALID',
                 W_init=w_init)(net)
    net = InstanceNorm2d(act=lrelu, gamma_init=g_init)(net)

    net = Conv2d(1, (5, 5), (5, 5), padding='VALID', W_init=w_init)(net)
    scores = Reshape(shape=[-1, 1])(net)

    return tl.models.Model(inputs=inputs, outputs=scores, name=None)
Example #5
0
def get_G(name=None):
    """ResNet-style image generator for ``(batch_size, 256, 256, 3)`` input.

    Layout: 7x7 conv (32) -> two 3x3 stride-2 convs (64, 128) -> nine
    residual blocks at 128 channels -> two 3x3 stride-2 DeConvs (64, 32) ->
    7x7 conv to 3 tanh channels. Every hidden conv/deconv is followed by
    InstanceNorm with ReLU, except the second norm inside each residual
    block, which has no activation argument.

    Args:
        name: Optional model name.

    Returns:
        A ``Model`` whose input batch size is fixed to ``flags.batch_size``.
    """
    base = 32
    w_init = tf.random_normal_initializer(stddev=0.02)

    image = Input((flags.batch_size, 256, 256, 3))
    net = Conv2d(base, (7, 7), (1, 1), W_init=w_init)(image)
    net = InstanceNorm2d(act=tf.nn.relu)(net)

    # Downsample twice.
    for mult in (2, 4):
        net = Conv2d(base * mult, (3, 3), (2, 2), W_init=w_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)

    # Nine residual blocks.
    for _ in range(9):
        branch = Conv2d(base * 4, (3, 3), (1, 1), W_init=w_init)(net)
        branch = InstanceNorm2d(act=tf.nn.relu)(branch)
        branch = Conv2d(base * 4, (3, 3), (1, 1), W_init=w_init)(branch)
        branch = InstanceNorm2d()(branch)
        net = Elementwise(tf.add)([net, branch])

    # Upsample back to the input resolution.
    for mult in (2, 1):
        net = DeConv2d(base * mult, (3, 3), (2, 2), W_init=w_init)(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)

    net = Conv2d(3, (7, 7), (1, 1), act=tf.nn.tanh, W_init=w_init)(net)

    return Model(inputs=image, outputs=net, name=name)
Example #6
0
def get_G_zc(shape_z=(None, flags.zc_dim), gf_dim=64):
    """DCGAN-style generator mapping a ``zc`` code to a 256-channel map.

    A bias-free Dense layer produces ``gf_dim * 8 * s16 * s16`` units, which
    are reshaped to ``(s16, s16, gf_dim * 8)`` with ``s16 = 64 // 16 = 4``.
    Three 5x5 stride-2 DeConvs (gf_dim*4, gf_dim*2, gf_dim) follow, each
    with InstanceNorm + ReLU. A final 5x5 stride-2 tanh DeConv emits 256
    channels.

    Args:
        shape_z: Shape of the code input, ``(None, zc_dim)``.
        gf_dim: Base channel multiplier.

    Returns:
        A ``tl.models.Model`` named 'Generator_zc'.

    NOTE(review): ``w_init`` and ``g_init`` come from module scope.
    """
    # reference: DCGAN generator
    output_size = 64
    s16 = output_size // 16

    z = Input(shape_z)
    net = Dense(n_units=gf_dim * 8 * s16 * s16, W_init=w_init,
                b_init=None)(z)
    net = Reshape(shape=[-1, s16, s16, gf_dim * 8])(net)
    net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

    for mult in (4, 2, 1):
        net = DeConv2d(gf_dim * mult, (5, 5), (2, 2), W_init=w_init,
                       b_init=None)(net)
        net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

    net = DeConv2d(256, (5, 5), (2, 2), act=tf.nn.tanh, W_init=w_init)(net)

    return tl.models.Model(inputs=z, outputs=net, name='Generator_zc')
def get_patch_D(input_shape, name="discriminator"):
    """Patch discriminator that also exposes its intermediate feature maps.

    Layout: 4x4 stride-2 conv (64, leaky ReLU 0.2), then three bias-free
    4x4 stride-2 convs (128, 256, 512), each followed by InstanceNorm with
    leaky ReLU. A final 1x1 bias-free conv produces one channel. All convs
    use 'SAME' padding.

    Args:
        input_shape: Shape given to ``Input``.
        name: Model name.

    Returns:
        A ``Model`` whose outputs are a list of five tensors: the first
        conv's activated output, the three middle convs' outputs taken
        *before* their InstanceNorm, and the final 1-channel map.
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)

    features = []

    nin = Input(input_shape)
    net = Conv2d(df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME',
                 W_init=w_init)(nin)
    features.append(net)

    for mult in (2, 4, 8):
        net = Conv2d(df_dim * mult, (4, 4), (2, 2), padding='SAME',
                     W_init=w_init, b_init=None)(net)
        # Collected before normalisation, as in the original design.
        features.append(net)
        net = InstanceNorm2d(act=lrelu, gamma_init=gamma_init)(net)

    net = Conv2d(1, (1, 1), (1, 1), padding='SAME', W_init=w_init,
                 b_init=None)(net)
    features.append(net)

    return Model(inputs=nin, outputs=features, name=name)
Example #8
0
def get_D(name=None):
    """PatchGAN discriminator over a random 70x70 crop of 256x256x3 images.

    The input is randomly cropped to ``[batch_size, 70, 70, 3]``. It is then
    passed through a 4x4 stride-2 conv (64, leaky ReLU 0.2) and three 4x4
    stride-2 convs (128, 256, 512), each with InstanceNorm + leaky ReLU.
    A 4x4 stride-4 VALID conv to one channel follows, and the result is
    flattened.

    Args:
        name: Optional model name.

    Returns:
        A ``Model`` producing a flattened score. The build asserts that the
        last dimension is 1.

    NOTE(review): the 70 -> 35 -> 18 -> 9 -> 5 -> 1 size chain assumes
    ``Conv2d``'s default padding is 'SAME'. Confirm this for the installed
    TensorLayer version.
    """
    df_dim = 64
    w_init = tf.random_normal_initializer(stddev=0.02)
    lrelu = lambda x: tl.act.lrelu(x, 0.2)

    nx = Input((flags.batch_size, 256, 256, 3))

    # patchGAN: score a random 70x70 patch
    net = Lambda(lambda x: tf.image.random_crop(
        x, [flags.batch_size, 70, 70, 3]))(nx)
    net = Conv2d(df_dim, (4, 4), (2, 2), act=lrelu, W_init=w_init)(net)

    for mult in (2, 4, 8):
        net = Conv2d(df_dim * mult, (4, 4), (2, 2), W_init=w_init)(net)
        net = InstanceNorm2d(act=lrelu)(net)

    net = Conv2d(1, (4, 4), (4, 4), padding='VALID', W_init=w_init)(net)
    net = Flatten()(net)
    assert net.shape[-1] == 1

    return Model(inputs=nx, outputs=net, name=name)
def cycle_G(input_shape, name="dx_generator"):
    """Build a CycleGAN-style ResNet generator.

    Layout: 7x7 stride-1 conv (64) -> 3x3 stride-2 convs (128, 256), each
    followed by InstanceNorm + ReLU -> nine residual blocks at 256 channels
    -> two ``DeConv2d`` stages (128, 64) with InstanceNorm + ReLU -> 7x7
    conv to a single tanh-activated channel. Convs use 'SAME' padding.

    Args:
        input_shape: Shape given to ``Input``.
        name: Model name.

    Returns:
        A ``Model`` mapping ``input_shape`` inputs to a 1-channel tanh map.
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)

    nin = Input(input_shape)
    n = Conv2d(64, (7, 7), (1, 1), padding='SAME', W_init=w_init)(nin)
    n1 = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(n)

    n2 = Conv2d(128, (3, 3), (2, 2), padding='SAME', W_init=w_init)(n1)
    n2 = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(n2)

    n = Conv2d(256, (3, 3), (2, 2), padding='SAME', W_init=w_init)(n2)
    n = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(n)

    # B residual blocks
    for _ in range(9):
        nn = Conv2d(256, (3, 3), (1, 1),
                    padding='SAME',
                    W_init=w_init,
                    b_init=None)(n)
        nn = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(nn)
        nn = Conv2d(256, (3, 3), (1, 1),
                    padding='SAME',
                    W_init=w_init,
                    b_init=None)(nn)
        # No activation before the skip sum. A ReLU here would clamp the
        # residual branch to non-negative values, unlike the standard
        # conv-IN-ReLU-conv-IN residual block.
        nn = InstanceNorm2d(act=None, gamma_init=g_init)(nn)
        n = Elementwise(tf.add)([n, nn])

    # Filter size, stride and padding come from DeConv2d's defaults
    # (presumably (3, 3), (2, 2), 'SAME' in TL 2.x; confirm this for the
    # installed version).
    n = DeConv2d(n_filter=128)(n)
    n = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(n)

    n = DeConv2d(n_filter=64)(n)
    n = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(n)

    nn = Conv2d(1, (7, 7), (1, 1),
                act=tf.nn.tanh,
                padding='SAME',
                W_init=w_init)(n)
    G = Model(inputs=nin, outputs=nn, name=name)

    return G
Example #10
0
def discriminator(name="discriminator"):
    """Masked PatchGAN discriminator with switchable normalisation.

    Inputs (in order): the image, a mask, a ``transition_rate`` scalar and
    a ``donorm`` scalar. The image is multiplied by
    ``my_cast(mask >= transition_rate)``. It then passes through four stages
    of [constant pad 2 -> 4x4 VALID conv -> ``my_cond`` -> leaky ReLU 0.2]
    with ``ndf``, ``ndf*2``, ``ndf*4`` filters (stride 2) and ``ndf*8``
    (stride 1). A final pad + 4x4 stride-1 VALID conv produces one channel.

    Args:
        name: Variable-scope name. The returned ``Model`` itself is unnamed.

    Returns:
        A ``Model`` with inputs ``[image, mask, transition_rate, donorm]``.

    NOTE(review): ``my_cond`` receives ``[donorm, normed, raw]`` and
    presumably selects between the InstanceNorm'd and raw conv output based
    on ``donorm``. ``my_cast`` presumably casts the boolean mask to float.
    Both are defined elsewhere; confirm them. ``ndf``, ``IMG_WIDTH``,
    ``IMG_HEIGHT`` and ``IMG_CHANNELS`` are module globals. The Input shape
    lists width before height.
    """
    with tf.compat.v1.variable_scope(name):
        inputdisc_in = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS],
                             dtype=tf.float32)
        mask_in = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS],
                        dtype=tf.float32)
        transition_rate = Input(shape=[1], dtype=tf.float32)
        donorm = Input(shape=[1], dtype=tf.float32)

        keep = Elementwise(combine_fn=tf.greater_equal)(
            [mask_in, transition_rate])
        mask = Lambda(fn=my_cast)(keep)
        net = Elementwise(combine_fn=tf.multiply)([inputdisc_in, mask])

        f = 4
        padw = 2
        lrelu = lambda x: tl.act.lrelu(x, 0.2)

        def pad(x):
            # Constant-pad H and W by ``padw`` on each side.
            return PadLayer([[0, 0], [padw, padw], [padw, padw], [0, 0]],
                            "CONSTANT")(x)

        def conv(x, n_filter, stride):
            # f x f VALID conv with truncated-normal weights and zero bias.
            return Conv2d(n_filter=n_filter,
                          filter_size=(f, f),
                          strides=(stride, stride),
                          padding="VALID",
                          act=None,
                          W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                          b_init=tf.constant_initializer(0.0))(x)

        stages = ((ndf, 2), (ndf * 2, 2), (ndf * 4, 2), (ndf * 8, 1))
        for n_filter, stride in stages:
            raw = conv(pad(net), n_filter, stride)
            net = Lambda(fn=my_cond)(
                [donorm, InstanceNorm2d(act=None)(raw), raw])
            net = Lambda(fn=lrelu)(net)

        logits = conv(pad(net), 1, 1)

        return Model(inputs=[inputdisc_in, mask_in, transition_rate, donorm],
                     outputs=logits)
Example #11
0
def build_generator_9blocks(name="generator", skip=False):
    """ResNet generator with nine residual blocks.

    Layout: constant pad 3 + 7x7 VALID conv (ngf) -> two 3x3 stride-2 SAME
    convs (ngf*2, ngf*4) -> nine ``build_resnet_block`` calls ("r1".."r9")
    at ngf*4 -> two 3x3 stride-2 SAME DeConvs (ngf*2, ngf) -> 7x7 SAME
    conv to ``IMG_CHANNELS``. Every hidden conv/deconv is followed by
    InstanceNorm + ReLU. The output is ``tanh(input + conv)`` when
    ``skip is True``, else ``tanh(conv)``.

    Args:
        name: Variable-scope name. The returned ``Model`` itself is unnamed.
        skip: Add the input image before the final tanh. It is compared
            with ``is True``, so only the literal ``True`` enables it.

    Returns:
        A ``Model`` from the image input to the tanh output.

    NOTE(review): ``ngf``, ``IMG_*`` and ``build_resnet_block`` are defined
    elsewhere in the module.
    """
    with tf.compat.v1.variable_scope(name):
        inputgen = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS],
                         dtype=tf.float32)
        f = 7
        ks = 3
        padding = "CONSTANT"

        def layer_init():
            # Fresh initializer objects for every layer, as before.
            return dict(act=None,
                        W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                        b_init=tf.constant_initializer(0.0))

        net = PadLayer([[0, 0], [ks, ks], [ks, ks], [0, 0]],
                       padding)(inputgen)
        net = Conv2d(n_filter=ngf, filter_size=(f, f), strides=(1, 1),
                     padding="VALID", **layer_init())(net)
        net = InstanceNorm2d(act=tf.nn.relu)(net)

        for mult in (2, 4):
            net = Conv2d(n_filter=ngf * mult, filter_size=(ks, ks),
                         strides=(2, 2), padding="SAME",
                         **layer_init())(net)
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        for idx in range(1, 10):
            net = build_resnet_block(net, ngf * 4, "r%d" % idx, padding)

        for mult in (2, 1):
            net = DeConv2d(n_filter=ngf * mult, filter_size=(ks, ks),
                           strides=(2, 2), padding="SAME",
                           **layer_init())(net)
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        o_c6 = Conv2d(n_filter=IMG_CHANNELS, filter_size=(f, f),
                      strides=(1, 1), padding="SAME", **layer_init())(net)

        if skip is True:
            pre_tanh = Elementwise(combine_fn=tf.add)([inputgen, o_c6])
        else:
            pre_tanh = o_c6
        out_gen = Lambda(tf.nn.tanh)(pre_tanh)

        return Model(inputs=inputgen, outputs=out_gen)
Example #12
0
def autoenc_upsample(name):
    """Encoder/decoder producing a single-channel sigmoid map.

    Layout: reflect pad 3 + 7x7 stride-2 VALID conv (ngf) -> 3x3 stride-2
    SAME conv (ngf*2), each with InstanceNorm + ReLU. Then one
    ``build_resnet_block_Att`` ("r1"). Two stages follow, each doubling H and W via
    ``upsamplingDeconv`` ("up1", "up2") and then applying reflect pad 1 +
    3x3 VALID conv (ngf*2, then ngf) + InstanceNorm + ReLU. The network ends
    with reflect pad 3 + 7x7 VALID conv to one channel and a sigmoid.

    Args:
        name: Variable-scope name. The returned ``Model`` itself is unnamed.

    Returns:
        A ``Model`` from the image input to the sigmoid output.

    NOTE(review): ``upsamplingDeconv``, ``ngf`` and ``IMG_*`` are defined
    elsewhere. The target sizes are read from static shapes, so H and W
    must be known when the graph is built.
    """
    with tf.compat.v1.variable_scope(name):
        inputae = Input(shape=[None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS],
                        dtype=tf.float32)
        f = 7
        ks = 3
        padding = "REFLECT"

        def conv(x, n_filter, ksize, stride, pad_mode):
            # Linear conv with truncated-normal weights and zero bias.
            return Conv2d(n_filter=n_filter,
                          filter_size=(ksize, ksize),
                          strides=(stride, stride),
                          act=None,
                          padding=pad_mode,
                          W_init=tf.initializers.TruncatedNormal(stddev=0.02),
                          b_init=tf.constant_initializer(0.0))(x)

        net = PadLayer([[0, 0], [ks, ks], [ks, ks], [0, 0]],
                       padding)(inputae)
        net = conv(net, ngf, f, 2, "VALID")
        net = InstanceNorm2d(act=tf.nn.relu)(net)

        net = conv(net, ngf * 2, ks, 2, "SAME")
        net = InstanceNorm2d(act=tf.nn.relu)(net)

        net = build_resnet_block_Att(net, ngf * 2, "r1", padding)

        for n_filter, up_name in ((ngf * 2, "up1"), (ngf, "up2")):
            dims = net.get_shape().as_list()
            net = upsamplingDeconv(net,
                                   size=[dims[1] * 2, dims[2] * 2],
                                   name=up_name)
            net = PadLayer([[0, 0], [1, 1], [1, 1], [0, 0]], padding)(net)
            net = conv(net, n_filter, 3, 1, "VALID")
            net = InstanceNorm2d(act=tf.nn.relu)(net)

        net = PadLayer([[0, 0], [3, 3], [3, 3], [0, 0]], padding)(net)
        net = conv(net, 1, f, 1, "VALID")

        output = Lambda(tf.nn.sigmoid)(net)
        return Model(inputs=inputae, outputs=output)
def u_net(inputs, refine=False):
    """U-Net with eight stride-2 encoder stages and skip connections.

    The encoder applies 4x4 stride-2 SAME convs with 64, 128, 256 and then
    five times 512 filters, each followed by InstanceNorm + leaky ReLU 0.2.
    The decoder applies 4x4 ``DeConv2d`` layers (other settings at their
    defaults) with 512, 512, 512, 512, 256, 128 and 64 filters. The first
    three are followed by Dropout(0.5). Each of these seven outputs is
    InstanceNorm + ReLU'd and concatenated with the mirrored encoder output
    (conv7 down to conv1). A last 64-filter DeConv + InstanceNorm + ReLU
    feeds a 1x1 tanh conv to one channel.

    Args:
        inputs: Output of an ``Input`` layer (or another layer output).
        refine: If true, combine the prediction with ``inputs`` via
            ``RampElementwise(tf.add, act=tl.act.ramp, v_min=-1)``.

    Returns:
        The output tensor (not a ``Model``).
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)
    lrelu = lambda x: tl.act.lrelu(x, 0.2)

    # ENCODER
    skips = []
    net = inputs
    for n_filter in (64, 128, 256, 512, 512, 512, 512, 512):
        net = Conv2d(n_filter, (4, 4), (2, 2), padding='SAME',
                     W_init=w_init)(net)
        net = InstanceNorm2d(act=lrelu, gamma_init=g_init)(net)
        skips.append(net)

    # DECODER: (filters, use_dropout) paired with conv7 ... conv1.
    decoder_plan = ((512, True), (512, True), (512, True), (512, False),
                    (256, False), (128, False), (64, False))
    for (n_filter, use_dropout), skip in zip(decoder_plan,
                                             reversed(skips[:-1])):
        net = DeConv2d(n_filter=n_filter, filter_size=(4, 4))(net)
        if use_dropout:
            net = Dropout(0.5)(net)
        net = Concat()(
            [InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net), skip])

    net = DeConv2d(n_filter=64, filter_size=(4, 4))(net)
    net = InstanceNorm2d(act=tf.nn.relu, gamma_init=g_init)(net)

    prediction = Conv2d(1, (1, 1), (1, 1),
                        act=tf.nn.tanh,
                        padding='SAME',
                        W_init=w_init)(net)

    if refine:
        prediction = RampElementwise(tf.add, act=tl.act.ramp,
                                     v_min=-1)([prediction, inputs])

    return prediction