Example #1
def dis_resblock(inputs, depth=128, downsample=False, scope=None, is_training=True, first=False):
    with tf.variable_scope(scope, 'dis_resblock', [inputs]) as sc:
        # Residual path: (ReLU ->) conv -> ReLU -> conv.  The very first block
        # skips the leading ReLU because its input is the raw image.
        net = inputs
        if not first:
            net = tf.nn.relu(net)
        net = conv2d(net, depth, scope='conv1')
        net = tf.nn.relu(net)
        net = conv2d(net, depth, scope='conv2')

        # Shortcut path: match channels with a 1x1 conv and halve the spatial
        # size with 2x2 average pooling when downsampling.  The first block
        # pools before the 1x1 conv; later blocks conv first, then pool.
        shortcut = inputs
        if first and downsample:
            shortcut = slim.avg_pool2d(shortcut, [2, 2])
            shortcut = conv2d(shortcut, depth, scope="conv_sc", kernel_size=[1, 1],
                              weights_initializer=get_initializer(relu=False))
            net = slim.avg_pool2d(net, [2, 2])
        else:
            if inputs.shape[-1] != depth or downsample:
                shortcut = conv2d(shortcut, depth, scope="conv_sc", kernel_size=[1, 1],
                                  weights_initializer=get_initializer(relu=False))
            if downsample:
                shortcut = slim.avg_pool2d(shortcut, [2, 2])
                net = slim.avg_pool2d(net, [2, 2])

        output = shortcut + net
        return output
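The conv-then-average-pool downsampling used above can be sketched with stock TF 1.x layers instead of the repo's conv2d/get_initializer helpers; the shapes and filter counts below are illustrative only and there is no spectral normalization here.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Plain-ops sketch of a downsampling residual block like dis_resblock.
x = tf.placeholder(tf.float32, [None, 64, 64, 64])
h = tf.layers.conv2d(tf.nn.relu(x), 128, 3, padding="same")   # conv1
h = tf.layers.conv2d(tf.nn.relu(h), 128, 3, padding="same")   # conv2
h = tf.layers.average_pooling2d(h, 2, 2)                      # 64x64 -> 32x32
sc = tf.layers.conv2d(x, 128, 1)                              # 1x1 conv to match channels
sc = tf.layers.average_pooling2d(sc, 2, 2)
out = h + sc                                                   # [None, 32, 32, 128]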
Example #2
def gen_resblock(inputs, depth=128, upsample=False, scope=None, is_training=True,
                 normalizer_fn=None, normalizer_params=None):
    with tf.variable_scope(scope, 'gen_resblock', [inputs]) as sc:
        # Shortcut path: nearest-neighbor upsample (if requested) and a 1x1 conv
        # to match the output channel count.
        shortcut = inputs
        if upsample or inputs.shape[-1] != depth:
            _, h, w, _ = shortcut.shape
            if upsample:
                shortcut = tf.image.resize_nearest_neighbor(shortcut, (h*2, w*2))
            shortcut = conv2d(shortcut, depth, scope="conv_sc", kernel_size=[1, 1],
                              weights_initializer=get_initializer(relu=False))

        # Residual path: pre-activation (BN -> ReLU), upsample, conv, BN -> ReLU, conv.
        net = inputs
        if normalizer_fn is not None:
            net = normalizer_fn(net, **normalizer_params, scope="bn1")
        net = tf.nn.relu(net)

        if upsample:
            _, h, w, _ = inputs.shape
            net = tf.image.resize_nearest_neighbor(net, (h*2, w*2))
        net = conv2d(net, depth, scope='conv1')
        if normalizer_fn is not None:
            net = normalizer_fn(net, **normalizer_params, scope="bn2")
        net = tf.nn.relu(net)
        net = conv2d(net, depth, scope='conv2')

        output = shortcut + net
        return output
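The upsampling step in this block is a nearest-neighbor resize followed by a convolution. A minimal sketch with stock TF 1.x ops (not the repo's conv2d helper), using an illustrative 4x4x256 input:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Nearest-neighbor upsample then 3x3 conv, as in the residual path above.
x = tf.placeholder(tf.float32, [8, 4, 4, 256])
up = tf.image.resize_nearest_neighbor(x, (8, 8))      # 4x4 -> 8x8
out = tf.layers.conv2d(up, 256, 3, padding="same")    # [8, 8, 8, 256]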
Example #3
def generator(z, is_training, y=None, scope=None, num_classes=None):
    # z is a list: z[0] is the latent noise batch and z[1] (in the conditional
    # case) holds one-hot class labels.  With the projection discriminator the
    # labels are fed only through conditional batch norm, not concatenated.
    inputs = tf.concat(z, 1)
    if args.projection:
        inputs = z[0]
    if args.unconditional:
        labels = None
    else:
        labels = tf.argmax(z[1], axis=1)
    with tf.variable_scope(scope or "generator") as scp:
        end_pts_collection = scp.name+"end_pts"
        if num_classes is None:
            num_classes = args.num_classes
        with slim.arg_scope(gen_arg_scope(is_training, end_pts_collection)):
            with slim.arg_scope([cond_batch_norm],
                                n_labels=num_classes,
                                labels=labels):
                # Project the latent to a 4x4 feature map, then upsample through
                # four residual blocks: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
                gf_dim = args.gen_linear_dim
                net = slim.fully_connected(inputs, 4*4*gf_dim, scope="projection",
                                           weights_initializer=get_initializer(relu=False))
                net = tf.reshape(net, [-1, 4, 4, gf_dim])
                net = gen_resblock(net, upsample=True, scope='res1', depth=512)
                net = gen_resblock(net, upsample=True, scope='res2', depth=256)
                net = gen_resblock(net, upsample=True, scope='res3', depth=128)
                net = gen_resblock(net, upsample=True, scope='res4', depth=64)
                # Final BN -> ReLU -> conv to 3 channels with tanh output.
                net = slim.batch_norm(net, **batch_norm_params(is_training), scope="bn_final")
                net = tf.nn.relu(net)
                net = conv2d(net, 3,
                             activation_fn=tf.nn.tanh,
                             normalizer_fn=None,
                             normalizer_params=None,
                             weights_initializer=get_initializer(relu=False),
                             scope="conv_final")
                end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)
    return net, end_pts
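A hypothetical call sketch: the generator expects z as a list whose first element is the noise batch and, in the conditional case, whose second element is a one-hot label batch. Names and sizes below are illustrative, and the actual call depends on the repo's args and helpers, so the call itself is left commented out.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

batch_size, z_dim, num_classes = 16, 128, 10
noise = tf.random_normal([batch_size, z_dim])
class_ids = tf.random_uniform([batch_size], 0, num_classes, dtype=tf.int32)
labels = tf.one_hot(class_ids, num_classes)
# fake_images, gen_end_pts = generator([noise, labels], is_training=True)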
Example #4
def discriminator(inputs, is_training, gen_input=None, reuse=None, scope=None, num_classes=None):
    with tf.variable_scope(scope or "discriminator", values=[inputs], reuse=reuse) as scp:
        end_pts_collection = scp.name+"end_pts"
        if num_classes is None:
            num_classes = args.num_classes
        with slim.arg_scope(disc_arg_scope(is_training, end_pts_collection)):
            # Five downsampling residual blocks, then a ReLU and global pooling.
            net = inputs
            net = dis_resblock(net, first=True, downsample=True, scope='res1', depth=64)
            net = dis_resblock(net, downsample=True, scope='res2', depth=128)
            net = dis_resblock(net, downsample=True, scope='res3', depth=256)
            net = dis_resblock(net, downsample=True, scope='res4', depth=512)
            net = dis_resblock(net, downsample=True, scope='res5', depth=1024)
            net = tf.nn.relu(net)
            # Global sum (or mean) pooling over the spatial dimensions.
            if args.sum_pooling:
                net = tf.reduce_sum(net, [1, 2], keepdims=True)
            else:
                net = tf.reduce_mean(net, [1, 2], keepdims=True)
            activations = tf.squeeze(net, [1, 2], name="squeeze")  # [batch_size, num_filters]

            # Real/fake head and auxiliary class head, both 1x1 convs on the
            # pooled features.
            gan_logits = conv2d(net, 1, kernel_size=[1, 1],
                                activation_fn=None,
                                normalizer_fn=None,
                                normalizer_params=None,
                                weights_initializer=get_initializer(relu=False),
                                scope="fc1")
            class_logits = conv2d(net, num_classes, kernel_size=[1, 1],
                                  activation_fn=None,
                                  normalizer_fn=None,
                                  normalizer_params=None,
                                  weights_initializer=get_initializer(relu=False),
                                  scope="fc1_ac")
            end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)
            gan_logits = tf.squeeze(gan_logits, [1, 2], name="squeeze")
            class_logits = tf.squeeze(class_logits, [1, 2], name="squeeze")
            # Projection discriminator: add the inner product of a (possibly
            # spectrally normalized) class embedding and the pooled features
            # to the real/fake logit.
            if gen_input is not None and args.projection:
                y = gen_input[1]
                embedding_W = slim.model_variable('embedding', shape=[num_classes, net.shape[-1]],
                                                  initializer=get_initializer(relu=False))
                if args.spectral_normalization:
                    upd_coll = None if not reuse else "NO_OPS"
                    embedding_W = spectral_normed_weight(embedding_W, update_collection=upd_coll)
                embedding = tf.matmul(y, embedding_W)  # [batch_size, num_filters]
                gan_logits += tf.reduce_sum(embedding * activations, axis=1, keepdims=True)
    return gan_logits, class_logits, end_pts
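The term added to gan_logits at the end is the projection-discriminator inner product between a learned class embedding and the pooled features, roughly y^T V phi(x). A minimal sketch with illustrative sizes, one-hot labels, and no spectral normalization:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

batch_size, num_classes, channels = 8, 10, 1024
phi = tf.placeholder(tf.float32, [batch_size, channels])    # pooled features, as in `activations`
y = tf.placeholder(tf.float32, [batch_size, num_classes])   # one-hot labels
V = tf.get_variable("class_embedding", [num_classes, channels])
projection = tf.reduce_sum(tf.matmul(y, V) * phi, axis=1, keepdims=True)  # [batch_size, 1]
# gan_logits = gan_logits + projection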
Example #5
def discriminator(inputs,
                  is_training,
                  y=None,
                  gen_input=None,
                  reuse=None,
                  scope=None,
                  num_classes=None):
    with tf.variable_scope(scope or "discriminator",
                           values=[inputs],
                           reuse=reuse) as scp:
        end_pts_collection = scp.name + "end_pts"
        if num_classes is None:
            num_classes = args.num_classes
        with slim.arg_scope(disc_arg_scope(is_training, end_pts_collection)):
            # DCGAN-style trunk: pairs of 3x3 stride-1 and 4x4 stride-2 convs
            # halve the spatial size three times.
            net = conv2d(inputs, 64, stride=1, kernel_size=[3, 3], scope="conv0_1")
            net = conv2d(net, 64, stride=2, kernel_size=[4, 4], scope="conv0_2")
            net = conv2d(net, 128, stride=1, kernel_size=[3, 3], scope="conv1_1")
            net = conv2d(net, 128, stride=2, kernel_size=[4, 4], scope="conv1_2")
            net = conv2d(net, 256, stride=1, kernel_size=[3, 3], scope="conv2_1")
            net = conv2d(net, 256, stride=2, kernel_size=[4, 4], scope="conv2_2")
            net = conv2d(net, 512, stride=1, kernel_size=[3, 3], scope="conv3")
            # Real/fake and class heads: 4x4 VALID convs collapse the remaining
            # spatial extent to 1x1, which tf.squeeze then removes.
            gan_logits = conv2d(net, 1,
                                activation_fn=None,
                                kernel_size=[4, 4],
                                stride=1,
                                padding="VALID",
                                normalizer_fn=None,
                                normalizer_params=None,
                                scope="gan_conv4")
            class_logits = conv2d(net, num_classes,
                                  activation_fn=None,
                                  kernel_size=[4, 4],
                                  stride=1,
                                  padding="VALID",
                                  normalizer_fn=None,
                                  normalizer_params=None,
                                  scope="cls_conv4")
            end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)
            gan_logits = tf.squeeze(gan_logits, [1, 2], name="squeeze")
            class_logits = tf.squeeze(class_logits, [1, 2], name="squeeze")
    return gan_logits, class_logits, end_pts
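Note on sizes (an assumption, since the input resolution is set elsewhere in the repo): with three stride-2 convolutions and "SAME" padding in the conv2d helper, the 4x4 VALID head convolutions reduce to 1x1 only for 32x32 inputs (32 -> 16 -> 8 -> 4 -> 1), which is what the final tf.squeeze calls rely on.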