コード例 #1
ファイル: model.py プロジェクト: liuchangshiye/SIFA
def build_resnet_block_ins(inputres, dim, name="resnet", padding="REFLECT"):
    """Build a residual block of two padded convolutions with instance norm.

    Each convolution is preceded by an explicit ``tf.pad`` of one pixel on
    axes 1 and 2 and is run with ``"VALID"`` padding, so (presumably, if the
    ``3, 3, 1, 1`` arguments are kernel and stride sizes) the spatial size is
    preserved. Inputs are assumed NHWC -- TODO confirm against callers.

    Args:
        inputres: Block input tensor; must be addable to the conv output, so
            presumably it already has ``dim`` channels.
        dim: Filter count passed to both ``general_conv2d_ga`` calls.
        name: Variable scope that wraps the block's variables.
        padding: ``tf.pad`` mode used before each convolution.

    Returns:
        ``tf.nn.relu(conv_path + inputres)``.
    """
    # Pad only axes 1 and 2 (batch and last axis untouched).
    spatial_pad = [[0, 0], [1, 1], [1, 1], [0, 0]]
    with tf.variable_scope(name):
        # First conv: activation left at the helper's default.
        hidden = layers.general_conv2d_ga(
            tf.pad(inputres, spatial_pad, padding),
            dim, 3, 3, 1, 1, 0.02, "VALID", "c1",
            norm_type='Ins')
        # Second conv: do_relu=False, presumably so no activation precedes the
        # residual sum (the sum itself is passed through relu below).
        hidden = layers.general_conv2d_ga(
            tf.pad(hidden, spatial_pad, padding),
            dim, 3, 3, 1, 1, 0.02, "VALID", "c2",
            do_relu=False, norm_type='Ins')
        return tf.nn.relu(hidden + inputres)
コード例 #2
ファイル: model.py プロジェクト: zwq1230/SIFA
def build_generator_resnet_9blocks(inputgen, inputimg, name="generator", skip=False):
    """Build the encoder / nine-residual-block / decoder generator.

    Reads the module-level globals ``ngf`` (base filter count) and
    ``BATCH_SIZE``; neither is defined in this function.

    Args:
        inputgen: Generator input tensor (presumably NHWC -- TODO confirm).
        inputimg: Tensor added to the last conv output when ``skip`` is True.
        name: Variable scope wrapping every generator variable.
        skip: The residual skip is used only when ``skip is True`` (identity
            check, so other truthy values do NOT enable it).

    Returns:
        ``tf.nn.tanh`` (op name ``"t1"``) of the final 1-filter convolution,
        optionally summed with ``inputimg`` first.

    Note:
        The two deconvolutions hard-code output shapes 128x128 and 256x256,
        so this presumably assumes 256x256 inputs -- TODO confirm.
    """
    wide_k = 7            # kernel size of the first and last convolutions
    k = 3                 # kernel size everywhere else
    pad_mode = "CONSTANT"
    with tf.variable_scope(name):
        # Explicit pad of k (== wide_k // 2 here) ahead of the 7x7 conv "c1",
        # which is called without an explicit padding argument.
        net = tf.pad(inputgen, [[0, 0], [k, k], [k, k], [0, 0]], pad_mode)
        net = layers.general_conv2d_ga(net, ngf, wide_k, wide_k, 1, 1, 0.02,
                                       name="c1", norm_type='Ins')
        net = layers.general_conv2d_ga(net, ngf * 2, k, k, 2, 2, 0.02,
                                       "SAME", "c2", norm_type='Ins')
        net = layers.general_conv2d_ga(net, ngf * 4, k, k, 2, 2, 0.02,
                                       "SAME", "c3", norm_type='Ins')

        # Nine residual blocks in scopes "r1" .. "r9".
        for idx in range(1, 10):
            net = build_resnet_block_ins(net, ngf * 4, "r%d" % idx, pad_mode)

        net = layers.general_deconv2d(net, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2,
                                      k, k, 2, 2, 0.02, "SAME", "c4", norm_type='Ins')
        net = layers.general_deconv2d(net, [BATCH_SIZE, 256, 256, ngf], ngf,
                                      k, k, 2, 2, 0.02, "SAME", "c5", norm_type='Ins')
        net = layers.general_conv2d_ga(net, 1, wide_k, wide_k, 1, 1, 0.02,
                                       "SAME", "c6", do_norm=False, do_relu=False)

        pre_activation = inputimg + net if skip is True else net
        return tf.nn.tanh(pre_activation, "t1")