def discriminator_model(pair, activation="relu"):
    """Score a pair of inputs by the inner product of their embeddings.

    Every element of ``pair`` is passed through the same sequence of layer
    calls: an ``nn.nin`` projection to ``dsize`` channels (``dsize`` is a
    module-level value), four ``nn.residual_block`` s using ``nn.nin`` as the
    convolution, an activation, and a final ``nn.nin`` to ``dsize`` channels.
    NOTE(review): whether the branches share variables depends on how ``nn``
    scopes its variables — confirm.

    Only the first two embeddings are combined. They are multiplied
    elementwise and summed over axes 1, 2 and 3. Any further elements of
    ``pair`` are still run through the layers but ignored.

    Args:
        pair: Sequence of (at least two) input tensors.
        activation: Activation name forwarded to ``nn.model_arg_scope``.

    Returns:
        The summed product with a trailing unit axis appended.
    """
    with nn.model_arg_scope(activation=activation):
        width = dsize  # same channel width for every layer in both branches
        embeddings = []
        for z in pair:
            feat = nn.nin(z, width)
            for _ in range(4):
                feat = nn.residual_block(feat, conv=nn.nin)
            feat = nn.activate(feat)
            embeddings.append(nn.nin(feat, width))
        product = embeddings[0] * embeddings[1]
        score = tf.reduce_sum(product, [1, 2, 3])
        return tf.expand_dims(score, -1)
def encoder_model(x, out_size, config, extra_resnets, activation="relu"):
    """Encode ``x`` into a pooled feature tensor with ``out_size`` channels.

    The network is built as follows:

    * a ``nn.conv2d`` to ``config[0]`` filters, followed by a residual block;
    * for every later entry of ``config``, a ``nn.downsample`` to that many
      filters, followed by a residual block;
    * ``extra_resnets`` additional residual blocks;
    * an activation, a mean over axes 1 and 2 (kept as size-1 axes), and a
      final ``nn.nin`` projection to ``out_size`` channels.

    NOTE(review): axes 1 and 2 are presumably the spatial axes of an NHWC
    tensor — confirm against the callers.

    Args:
        x: Input tensor.
        out_size: Number of output channels of the final ``nn.nin``.
        config: Indexable sequence of filter counts, one per resolution stage.
        extra_resnets: Number of residual blocks appended after the last stage.
        activation: Activation name forwarded to ``nn.model_arg_scope``.

    Returns:
        The projected, pooled tensor.
    """
    with nn.model_arg_scope(activation=activation):
        feats = nn.residual_block(nn.conv2d(x, config[0]))
        for n_filters in config[1:]:
            feats = nn.residual_block(nn.downsample(feats, n_filters))
        for _ in range(extra_resnets):
            feats = nn.residual_block(feats)
        feats = nn.activate(feats)
        pooled = tf.reduce_mean(feats, [1, 2], keepdims=True)
        return nn.nin(pooled, out_size)
def image_discriminator_model(x, config=None, activation="relu", coords=False):
    """Build an image discriminator and return its probabilities and logits.

    The network is built as follows:

    * a ``nn.conv2d`` to ``config[0]`` filters;
    * for every later entry of ``config``, a ``nn.downsample`` to that many
      filters, followed by a residual block;
    * an activation and another ``nn.conv2d`` to ``config[-1]`` filters;
    * a mean over axes 1 and 2 (kept as size-1 axes), a ``nn.nin`` to a
      single channel, and a mean over axes 1-3, which yields one logit per
      batch item.

    Args:
        x: Input image tensor.
        config: Non-empty indexable sequence of filter counts, one per stage.
            It defaults to ``None`` only to keep the historical signature; a
            value must always be supplied.
        activation: Activation name forwarded to ``nn.model_arg_scope``.
        coords: Forwarded to ``nn.model_arg_scope`` as ``coords``.

    Returns:
        ``(probs, logits)``, where ``probs = tf.nn.sigmoid(logits)``. Both
        tensors have a trailing unit axis.

    Raises:
        ValueError: If ``config`` is ``None`` or empty. Previously this case
            failed later with an opaque ``TypeError``/``IndexError`` on
            ``config[0]``.
    """
    # Validate before entering the arg scope so that no layers are created
    # for an invalid call.
    if config is None or len(config) == 0:
        raise ValueError(
            "image_discriminator_model requires a non-empty `config` "
            "sequence of filter counts"
        )
    with nn.model_arg_scope(activation=activation, coords=coords):
        h = nn.conv2d(x, config[0])
        for nf in config[1:]:
            h = nn.downsample(h, nf)
            h = nn.residual_block(h)
        h = nn.activate(h)
        h = nn.conv2d(h, config[-1])
        h = tf.reduce_mean(h, [1, 2], keepdims=True)
        h = nn.nin(h, 1)
        # Collapse the remaining size-1 axes to one value per batch item,
        # then restore a trailing unit axis.
        logits = tf.expand_dims(tf.reduce_mean(h, [1, 2, 3]), -1)
        return tf.nn.sigmoid(logits), logits
def pretty_discriminator_model(x, c):
    """Score image ``x`` against condition ``c``.

    Image path: this uses the module-level ``convconf`` and ``dsize``.

    * a ``nn.conv2d`` to ``convconf[0]`` filters;
    * for every later entry of ``convconf``, a ``nn.downsample`` to that many
      filters, followed by a residual block;
    * an activation, a ``nn.conv2d`` to ``dsize`` filters, and a mean over
      axes 1 and 2 (kept as size-1 axes).

    Condition path: a ``nn.nin`` projection of ``c`` to ``dsize`` channels,
    followed by two residual blocks that use ``nn.nin`` as the convolution.

    The two paths are multiplied elementwise and averaged over axes 1-3.
    NOTE(review): this relies on ``c`` broadcasting against the pooled image
    features — confirm its expected shape with the callers.

    Returns:
        ``(scores, features)``. ``scores`` is the averaged product with a
        trailing unit axis. ``features`` is the list of image-path tensors
        taken after the first convolution and after each
        downsample + residual stage.
    """
    with nn.model_arg_scope(activation="relu"):
        # Image path: the layer creation order matches the original.
        feat = nn.conv2d(x, convconf[0])
        features = [feat]
        for n_filters in convconf[1:]:
            feat = nn.residual_block(nn.downsample(feat, n_filters))
            features.append(feat)
        feat = nn.conv2d(nn.activate(feat), dsize)
        img_code = tf.reduce_mean(feat, [1, 2], keepdims=True)

        # Condition path: this is built only after the image path.
        cond = nn.nin(c, dsize)
        for _ in range(2):
            cond = nn.residual_block(cond, conv=nn.nin)

        score = tf.reduce_mean(img_code * cond, [1, 2, 3])
        return tf.expand_dims(score, -1), features