def BaselineBlock(x, bottleneck_ratio, name, update_collection=None,
                  act=tf.nn.relu):
    """Bottlenecked residual block whose residual branch is scaled by a learned scalar.

    Residual branch: act -> 3x3 ``ops.snconv2d`` down to
    ``C // bottleneck_ratio`` channels -> act -> 3x3 ``ops.snconv2d`` back to
    ``C`` channels, where ``C`` is the last dimension of ``x``. The branch
    output is multiplied by the scalar variable ``sigma_ratio``, which is
    initialised to 0.0, so a freshly built block returns its input unchanged.

    Args:
      x: input tensor; its last static dimension is used as the channel count C.
      bottleneck_ratio: integer divisor applied to C to get the hidden width.
      name: variable scope that holds this block's variables.
      update_collection: forwarded unchanged to both ``ops.snconv2d`` calls.
      act: activation applied before each convolution.

    Returns:
      ``x + sigma_ratio * residual_branch(x)``.
    """
    with tf.variable_scope(name):
        channels_in = x.shape.as_list()[-1]
        hidden_width = channels_in // bottleneck_ratio

        residual = ops.snconv2d(act(x), hidden_width, 3, 3, 1, 1,
                                update_collection=update_collection,
                                name='sn_conv1')
        residual = ops.snconv2d(act(residual), channels_in, 3, 3, 1, 1,
                                update_collection=update_collection,
                                name='sn_conv2')

        # Zero-initialised gate: the block starts out as the identity mapping.
        gate = tf.get_variable('sigma_ratio', [],
                               initializer=tf.constant_initializer(0.0))
        return x + gate * residual
def OptimizedBlock(x, out_channels, name,
                   update_collection=tf.GraphKeys.UPDATE_OPS, act=tf.nn.relu):
    """Downsampling residual block that applies no activation to its raw input.

    Main path: 3x3 ``ops.snconv2d`` -> act -> 3x3 ``ops.snconv2d`` -> dsample.
    Shortcut: dsample -> 1x1 ``ops.snconv2d`` to ``out_channels``.
    Unlike ``Block``, the input goes straight into the first convolution
    without ``act`` being applied first.

    Args:
      x: input tensor.
      out_channels: channel count produced by every convolution in the block.
      name: variable scope that holds this block's variables.
      update_collection: forwarded unchanged to every ``ops.snconv2d`` call.
      act: activation applied between the two 3x3 convolutions.

    Returns:
      Sum of the main path and the shortcut.
    """
    def sn_conv(inputs, kernel, scope):
        # Square kernel, stride 1, always projecting to ``out_channels``.
        return ops.snconv2d(inputs, out_channels, kernel, kernel, 1, 1,
                            update_collection=update_collection, name=scope)

    with tf.variable_scope(name):
        main = sn_conv(x, 3, 'sn_conv1')
        main = act(main)
        main = sn_conv(main, 3, 'sn_conv2')
        main = dsample(main)

        shortcut = dsample(x)
        shortcut = sn_conv(shortcut, 1, 'sn_conv3')
        return main + shortcut
def Block(x, labels, out_channels, num_classes, name):
    """Class-conditional upsampling residual block (generator style).

    Main path: conditional BN(``cbn_0``) -> ReLU -> usample -> 3x3
    ``sn_ops.snconv2d`` -> conditional BN(``cbn_1``) -> ReLU -> 3x3
    ``sn_ops.snconv2d``.
    Shortcut: usample -> 1x1 ``sn_ops.snconv2d`` to ``out_channels``.

    NOTE(review): this file later defines a second, discriminator-style
    ``def Block(x, out_channels, name, ...)``. When both live in the same
    module, that later ``def`` rebinds the name ``Block``, so this function
    is not what callers get when they look up ``Block`` at call time.

    Args:
      x: input tensor.
      labels: class labels passed to both conditional batch-norm layers
        (format as expected by ``sn_ops.ConditionalBatchNorm`` -- confirm).
      out_channels: channel count produced by every convolution in the block.
      num_classes: number of classes given to each ConditionalBatchNorm.
      name: variable scope that holds this block's variables.

    Returns:
      Sum of the shortcut and the main path.
    """
    with tf.variable_scope(name):
        cbn_first = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_0')
        cbn_second = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_1')

        h = usample(tf.nn.relu(cbn_first(x, labels)))
        h = sn_ops.snconv2d(h, out_channels, 3, 3, 1, 1, name='snconv1')
        h = tf.nn.relu(cbn_second(h, labels))
        h = sn_ops.snconv2d(h, out_channels, 3, 3, 1, 1, name='snconv2')

        skip = sn_ops.snconv2d(usample(x), out_channels, 1, 1, 1, 1,
                               name='snconv3')
        return skip + h
def Block(x, out_channels, name, update_collection=tf.GraphKeys.UPDATE_OPS,
          downsample=True, act=tf.nn.relu):
    """Pre-activation residual block (discriminator style), optionally downsampling.

    Main path: act -> 3x3 ``ops.snconv2d`` -> act -> 3x3 ``ops.snconv2d``,
    followed by ``dsample`` when ``downsample`` is truthy.
    Shortcut:
      * downsample: 1x1 ``ops.snconv2d`` to ``out_channels``, then ``dsample``;
      * no downsample but the channel count changes: 1x1 ``ops.snconv2d`` only;
      * otherwise: the input itself.

    Args:
      x: input tensor; its last static dimension is compared with
        ``out_channels`` to decide whether the shortcut needs a projection.
      out_channels: channel count produced by every convolution in the block.
      name: variable scope that holds this block's variables.
      update_collection: forwarded unchanged to every ``ops.snconv2d`` call.
      downsample: whether to apply ``dsample`` to both paths.
      act: activation applied before each 3x3 convolution.

    Returns:
      Sum of the shortcut and the main path.
    """
    def sn_conv(inputs, kernel, scope):
        # Square kernel, stride 1, always projecting to ``out_channels``.
        return ops.snconv2d(inputs, out_channels, kernel, kernel, 1, 1,
                            update_collection=update_collection, name=scope)

    with tf.variable_scope(name):
        channels_in = x.shape.as_list()[-1]

        h = sn_conv(act(x), 3, 'sn_conv1')
        h = sn_conv(act(h), 3, 'sn_conv2')

        if downsample:
            h = dsample(h)
            shortcut = dsample(sn_conv(x, 1, 'sn_conv3'))
        elif channels_in != out_channels:
            shortcut = sn_conv(x, 1, 'sn_conv3')
        else:
            shortcut = x
        return shortcut + h
def _g_block(x, labels, out_channels, num_classes, name):
    """Class-conditional upsampling residual block used by ``generator_test``.

    This has the same logic and the same scope and layer names as the
    generator-style ``Block(x, labels, out_channels, num_classes, name)`` in
    this file. It exists because a later, discriminator-style
    ``def Block(x, out_channels, name, update_collection, downsample, act)``
    in the same file rebinds the module-level name ``Block``. Calling
    ``Block`` with the generator argument order would therefore pass
    ``labels`` as ``out_channels`` and an integer as the variable-scope name.

    Main path: conditional BN(``cbn_0``) -> ReLU -> usample -> 3x3 snconv ->
    conditional BN(``cbn_1``) -> ReLU -> 3x3 snconv.
    Shortcut: usample -> 1x1 snconv to ``out_channels``.
    """
    with tf.variable_scope(name):
        bn0 = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_0')
        bn1 = sn_ops.ConditionalBatchNorm(num_classes, name='cbn_1')
        h = tf.nn.relu(bn0(x, labels))
        h = usample(h)
        h = sn_ops.snconv2d(h, out_channels, 3, 3, 1, 1, name='snconv1')
        h = tf.nn.relu(bn1(h, labels))
        h = sn_ops.snconv2d(h, out_channels, 3, 3, 1, 1, name='snconv2')
        skip = usample(x)
        skip = sn_ops.snconv2d(skip, out_channels, 1, 1, 1, 1, name='snconv3')
        return skip + h


def generator_test(zs, target_class, gf_dim, num_classes, reuse_vars=False):
    """Build the class-conditional generator graph and return its image output.

    Pipeline: ``sn_ops.snlinear`` to ``gf_dim*16*4*4`` units -> reshape to
    ``[-1, 4, 4, gf_dim*16]`` -> five conditional upsampling residual blocks
    with ``gf_dim * (16, 8, 4, 2, 1)`` channels, with a non-local block after
    the third -> BatchNorm -> ReLU -> 3x3 snconv to 3 channels -> tanh.

    The residual blocks go through the private ``_g_block`` helper rather than
    the module-level name ``Block``. A later ``def Block`` in this file, with
    the discriminator signature, rebinds that name, which would mis-bind
    these positional arguments.

    Args:
      zs: latent input to the first linear layer.
      target_class: class labels fed to every conditional batch norm.
      gf_dim: base channel width of the generator.
      num_classes: number of classes for the conditional batch norms.
      reuse_vars: if True, switch the *current* variable scope to reuse mode
        before any layers are built.

    Returns:
      The tanh-activated 3-channel output tensor. Its spatial size depends on
      ``usample``; it is presumably 128x128 if each block doubles resolution
      (confirm).
    """
    if reuse_vars:
        tf.get_variable_scope().reuse_variables()

    act0 = sn_ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    act1 = _g_block(act0, target_class, gf_dim * 16, num_classes, 'g_block1')
    act2 = _g_block(act1, target_class, gf_dim * 8, num_classes, 'g_block2')
    act3 = _g_block(act2, target_class, gf_dim * 4, num_classes, 'g_block3')
    # Second argument presumably the update collection (None here) -- confirm
    # against non_local.sn_non_local_block_sim.
    act3 = non_local.sn_non_local_block_sim(act3, None, name='g_non_local')
    act4 = _g_block(act3, target_class, gf_dim * 2, num_classes, 'g_block4')
    act5 = _g_block(act4, target_class, gf_dim, num_classes, 'g_block5')

    bn = sn_ops.BatchNorm(name='g_bn')
    act5 = tf.nn.relu(bn(act5))
    act6 = sn_ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
    return tf.nn.tanh(act6)