def dis_resblock(inputs, depth=128, downsample=False, scope=None, is_training=True, first=False):
    """Build one residual block of the discriminator.

    Main path: (ReLU, skipped when ``first``) -> ``conv1`` -> ReLU -> ``conv2``.
    Shortcut path: the identity, or a 1x1 convolution (``conv_sc``) whenever
    the input channel count differs from ``depth`` or ``downsample`` is set.
    With ``downsample`` both paths go through a 2x2 average pool; in the
    ``first`` block the shortcut is pooled *before* its 1x1 convolution, in
    every other block it is pooled *after* it.

    Args:
        inputs: Input tensor; its channel count is read from ``inputs.shape[-1]``.
        depth: Number of output channels of every convolution in the block.
        downsample: Whether to average-pool both paths with a 2x2 window.
        scope: Optional variable-scope name (defaults to ``'dis_resblock'``).
        is_training: Not referenced by this function.
        first: Whether this is the first block of the network (no leading ReLU).

    Returns:
        The element-wise sum of the shortcut and the main path.
    """

    def _project(tensor):
        # 1x1 convolution that maps the shortcut to `depth` channels.
        return conv2d(tensor, depth, scope="conv_sc", kernel_size=[1, 1],
                      weights_initializer=get_initializer(relu=False))

    def _halve(tensor):
        return slim.avg_pool2d(tensor, [2, 2])

    with tf.variable_scope(scope, 'dis_resblock', [inputs]):
        # Main path is built first so conv1/conv2 are created before conv_sc.
        residual = inputs if first else tf.nn.relu(inputs)
        residual = conv2d(residual, depth, scope='conv1')
        residual = conv2d(tf.nn.relu(residual), depth, scope='conv2')

        skip = inputs
        if first and downsample:
            skip = _project(_halve(skip))
            residual = _halve(residual)
        elif downsample:
            skip = _halve(_project(skip))
            residual = _halve(residual)
        elif inputs.shape[-1] != depth:
            skip = _project(skip)

        return skip + residual
def gen_resblock(inputs, depth=128, upsample=False, scope=None, is_training=True,
                 normalizer_fn=None, normalizer_params=None):
    """Build one residual block of the generator.

    Shortcut path: the identity, or (optional 2x nearest-neighbour resize ->
    1x1 convolution ``conv_sc``) whenever ``upsample`` is set or the input
    channel count differs from ``depth``.
    Main path: [norm ``bn1``] -> ReLU -> [2x nearest-neighbour resize] ->
    ``conv1`` -> [norm ``bn2``] -> ReLU -> ``conv2``. The normalisation steps
    are only applied when ``normalizer_fn`` is given.

    Args:
        inputs: Input tensor. When ``upsample`` is set its shape is unpacked
            into four dimensions, so it must be rank 4 (presumably NHWC --
            confirm against ``conv2d``'s data format).
        depth: Number of output channels of every convolution in the block.
        upsample: Whether to double the two middle (spatial) dimensions.
        scope: Optional variable-scope name (defaults to ``'gen_resblock'``).
        is_training: Not referenced by this function.
        normalizer_fn: Optional callable invoked as
            ``normalizer_fn(tensor, scope=..., **normalizer_params)``.
        normalizer_params: Optional dict of extra keyword arguments for
            ``normalizer_fn``; ``None`` is treated as an empty dict.

    Returns:
        The element-wise sum of the shortcut and the main path.
    """
    # `**None` raises TypeError, so a normalizer supplied without parameters
    # must be called with no extra keyword arguments instead.
    norm_kwargs = normalizer_params if normalizer_params is not None else {}

    with tf.variable_scope(scope, 'gen_resblock', [inputs]):
        target_size = None
        if upsample:
            # Both paths start from `inputs`, so they share one target size.
            _, height, width, _ = inputs.shape
            target_size = (height * 2, width * 2)

        # The shortcut is built first so conv_sc is created before conv1/conv2.
        shortcut = inputs
        if upsample or inputs.shape[-1] != depth:
            if upsample:
                shortcut = tf.image.resize_nearest_neighbor(shortcut, target_size)
            shortcut = conv2d(shortcut, depth, scope="conv_sc", kernel_size=[1, 1],
                              weights_initializer=get_initializer(relu=False))

        net = inputs
        if normalizer_fn is not None:
            net = normalizer_fn(net, scope="bn1", **norm_kwargs)
        net = tf.nn.relu(net)
        if upsample:
            net = tf.image.resize_nearest_neighbor(net, target_size)
        net = conv2d(net, depth, scope='conv1')
        if normalizer_fn is not None:
            net = normalizer_fn(net, scope="bn2", **norm_kwargs)
        net = tf.nn.relu(net)
        net = conv2d(net, depth, scope='conv2')

        return shortcut + net
def generator(z, is_training, y=None, scope=None, num_classes=None):
    """Build the residual generator network.

    A fully connected layer projects the input to a 4x4 feature map with
    ``args.gen_linear_dim`` channels, which then passes through four
    upsampling residual blocks (512, 256, 128, 64 channels), a final batch
    norm + ReLU and a 3-channel convolution with ``tanh`` activation.

    Args:
        z: Sequence of input tensors. All of them are concatenated along
            axis 1, unless ``args.projection`` is set, in which case only
            ``z[0]`` is used. Unless ``args.unconditional`` is set, ``z[1]``
            is arg-maxed along axis 1 to obtain labels (presumably a one-hot
            class encoding -- confirm against the caller).
        is_training: Forwarded to ``gen_arg_scope`` and ``batch_norm_params``.
        y: Not referenced by this function.
        scope: Optional variable-scope name (defaults to ``"generator"``).
        num_classes: Number of classes handed to ``cond_batch_norm``;
            defaults to ``args.num_classes``.

    Returns:
        A ``(images, end_points)`` tuple, where ``end_points`` is the dict
        built from this scope's end-points collection.
    """
    # The concat op is always created; with args.projection only z[0] is
    # actually fed forward.
    inputs = tf.concat(z, 1)
    if args.projection:
        inputs = z[0]
    labels = None if args.unconditional else tf.argmax(z[1], axis=1)

    upsampling_stages = (
        ('res1', 512),
        ('res2', 256),
        ('res3', 128),
        ('res4', 64),
    )

    with tf.variable_scope(scope or "generator") as scp:
        end_pts_collection = scp.name+"end_pts"
        if num_classes is None:
            num_classes = args.num_classes
        gf_dim = args.gen_linear_dim

        with slim.arg_scope(gen_arg_scope(is_training, end_pts_collection)), \
                slim.arg_scope([cond_batch_norm], n_labels=num_classes, labels=labels):
            net = slim.fully_connected(inputs, 4*4*gf_dim, scope="projection",
                                       weights_initializer=get_initializer(relu=False))
            net = tf.reshape(net, [-1, 4, 4, gf_dim])

            for block_scope, block_depth in upsampling_stages:
                net = gen_resblock(net, upsample=True, scope=block_scope, depth=block_depth)

            net = slim.batch_norm(net, **batch_norm_params(is_training), scope="bn_final")
            net = tf.nn.relu(net)
            net = conv2d(net, 3,
                         activation_fn=tf.nn.tanh,
                         normalizer_fn=None,
                         normalizer_params=None,
                         weights_initializer=get_initializer(relu=False),
                         scope="conv_final")

        end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)

    return net, end_pts
def discriminator(inputs, is_training, gen_input=None, reuse=None, scope=None, num_classes=None):
    """Build the residual discriminator with GAN and class heads.

    Five downsampling residual blocks (64 -> 1024 channels) are followed by
    a ReLU and global sum or mean pooling (chosen by ``args.sum_pooling``).
    Two 1x1 convolutions on the pooled features produce the GAN logit
    (``fc1``) and the class logits (``fc1_ac``). When ``gen_input`` is given
    and ``args.projection`` is set, the inner product between a class
    embedding and the pooled features is added to the GAN logit.

    NOTE(review): a second function named ``discriminator`` appears later in
    this file; if both live at module level, that later definition replaces
    this one -- confirm which is intended.

    Args:
        inputs: Input image tensor.
        is_training: Forwarded to ``disc_arg_scope``.
        gen_input: Optional sequence whose element ``[1]`` is multiplied with
            the ``[num_classes, channels]`` embedding matrix (presumably
            one-hot labels of shape ``[batch, num_classes]`` -- confirm).
        reuse: Variable-scope reuse flag; when truthy, the spectral-norm
            update collection for the embedding is ``"NO_OPS"``.
        scope: Optional variable-scope name (defaults to ``"discriminator"``).
        num_classes: Number of class logits; defaults to ``args.num_classes``.

    Returns:
        A ``(gan_logits, class_logits, end_points)`` tuple.
    """
    block_specs = (
        ('res1', 64, {'first': True}),
        ('res2', 128, {}),
        ('res3', 256, {}),
        ('res4', 512, {}),
        ('res5', 1024, {}),
    )

    def _logit_head(features, num_outputs, head_scope):
        # Linear 1x1 convolution: no activation and no normalisation.
        return conv2d(features, num_outputs,
                      kernel_size=[1, 1],
                      activation_fn=None,
                      normalizer_fn=None,
                      normalizer_params=None,
                      weights_initializer=get_initializer(relu=False),
                      scope=head_scope)

    with tf.variable_scope(scope or "discriminator", values=[inputs], reuse=reuse) as scp:
        end_pts_collection = scp.name+"end_pts"
        if num_classes is None:
            num_classes = args.num_classes

        with slim.arg_scope(disc_arg_scope(is_training, end_pts_collection)):
            net = inputs
            for block_scope, block_depth, extra in block_specs:
                net = dis_resblock(net, downsample=True, scope=block_scope,
                                   depth=block_depth, **extra)
            net = tf.nn.relu(net)

            # Global pooling over the two spatial axes, kept as size-1 dims
            # so the 1x1 convolution heads can still be applied.
            pool = tf.reduce_sum if args.sum_pooling else tf.reduce_mean
            net = pool(net, [1, 2], keepdims=True)
            activations = tf.squeeze(net, [1, 2], name="squeeze")  # [batch_size, num_filters]

            gan_logits = _logit_head(net, 1, "fc1")
            class_logits = _logit_head(net, num_classes, "fc1_ac")
            end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)

            gan_logits = tf.squeeze(gan_logits, [1, 2], name="squeeze")
            class_logits = tf.squeeze(class_logits, [1, 2], name="squeeze")

            # NOTE(review): the source formatting did not make it clear
            # whether this branch sits inside the arg_scope -- confirm.
            if gen_input is not None and args.projection:
                y = gen_input[1]
                embedding_W = slim.model_variable(
                    'embedding',
                    shape=[num_classes, net.shape[-1]],
                    initializer=get_initializer(relu=False))
                if args.spectral_normalization:
                    upd_coll = "NO_OPS" if reuse else None
                    embedding_W = spectral_normed_weight(embedding_W,
                                                         update_collection=upd_coll)
                embedding = tf.matmul(y, embedding_W)  # [batch_size, num_filters]
                projection = tf.reduce_sum(embedding * activations, axis=1, keepdims=True)
                gan_logits = gan_logits + projection

    return gan_logits, class_logits, end_pts
def discriminator(inputs, is_training, y=None, gen_input=None, reuse=None, scope=None, num_classes=None):
    """Build the plain convolutional discriminator with GAN and class heads.

    A trunk of seven convolutions (alternating 3x3 stride-1 and 4x4 stride-2
    layers, 64 -> 512 channels) feeds two 4x4 ``VALID`` convolutions without
    activation or normalisation: ``gan_conv4`` (one output) and ``cls_conv4``
    (``num_classes`` outputs). Both head outputs are squeezed over axes 1
    and 2, so the trunk output must be exactly 4x4 spatially.

    NOTE(review): an earlier function named ``discriminator`` (the residual
    variant) appears above in this file; if both live at module level, this
    definition replaces it -- confirm which is intended.

    Args:
        inputs: Input image tensor.
        is_training: Forwarded to ``disc_arg_scope``.
        y: Not referenced by this function.
        gen_input: Not referenced by this function.
        reuse: Variable-scope reuse flag.
        scope: Optional variable-scope name (defaults to ``"discriminator"``).
        num_classes: Number of class logits; defaults to ``args.num_classes``.

    Returns:
        A ``(gan_logits, class_logits, end_points)`` tuple.
    """
    # (scope, output channels, stride, kernel size) for each trunk layer.
    trunk_layers = (
        ("conv0_1", 64, 1, [3, 3]),
        ("conv0_2", 64, 2, [4, 4]),
        ("conv1_1", 128, 1, [3, 3]),
        ("conv1_2", 128, 2, [4, 4]),
        ("conv2_1", 256, 1, [3, 3]),
        ("conv2_2", 256, 2, [4, 4]),
        ("conv3", 512, 1, [3, 3]),
    )

    def _logit_head(features, num_outputs, head_scope):
        # 4x4 VALID convolution collapsing the feature map to raw logits.
        return conv2d(features, num_outputs,
                      activation_fn=None,
                      kernel_size=[4, 4],
                      stride=1,
                      padding="VALID",
                      normalizer_fn=None,
                      normalizer_params=None,
                      scope=head_scope)

    with tf.variable_scope(scope or "discriminator", values=[inputs], reuse=reuse) as scp:
        end_pts_collection = scp.name + "end_pts"
        if num_classes is None:
            num_classes = args.num_classes

        with slim.arg_scope(disc_arg_scope(is_training, end_pts_collection)):
            net = inputs
            for layer_scope, channels, stride, kernel in trunk_layers:
                net = conv2d(net, channels, stride=stride, kernel_size=kernel,
                             scope=layer_scope)

            gan_logits = _logit_head(net, 1, "gan_conv4")
            class_logits = _logit_head(net, num_classes, "cls_conv4")
            end_pts = slim.utils.convert_collection_to_dict(end_pts_collection)

            gan_logits = tf.squeeze(gan_logits, [1, 2], name="squeeze")
            class_logits = tf.squeeze(class_logits, [1, 2], name="squeeze")

    return gan_logits, class_logits, end_pts