def visual_decoder(self, bottleneck, clss, hparams):
    """Decode a latent code into a per-pixel Bernoulli image distribution.

    Maps [batch, bottleneck_bits] -> tfd.Independent(Bernoulli) over
    [batch, 64, 64, 1]. Each upsampling stage is a transposed conv
    followed by class-conditional instance norm and ReLU.

    Args:
      bottleneck: [batch, bottleneck_bits] latent tensor.
      clss: class-id tensor, flattened to [batch] for conditioning.
      hparams: hyperparameters; reads `base_depth` and `num_categories`.

    Returns:
      A `tfd.Independent` Bernoulli distribution named 'image'.
    """
    with tf.variable_scope('visual_decoder', reuse=tf.AUTO_REUSE):
        # Un-bottleneck the latent into a 4x4 spatial map with 64 channels.
        net = tf.layers.dense(bottleneck, 1024, activation=None)
        net = tf.reshape(net, [-1, 4, 4, 64])
        labels = tf.reshape(clss, [-1])

        def _deconv_block(x, depth, kernel, stride):
            # Transposed conv -> conditional instance norm -> ReLU.
            x = tf.layers.Conv2DTranspose(
                depth, kernel, stride, padding='SAME', activation=None)(x)
            x = ops.conditional_instance_norm(
                x, labels, hparams.num_categories)
            return tf.nn.relu(x)

        # Upsample 4x4 -> 64x64 (strided stages double the resolution).
        net = _deconv_block(net, 2 * hparams.base_depth, 4, 2)  # -> 8x8
        net = _deconv_block(net, 2 * hparams.base_depth, 4, 2)  # -> 16x16
        net = _deconv_block(net, 2 * hparams.base_depth, 5, 1)  # 16x16
        net = _deconv_block(net, 2 * hparams.base_depth, 5, 2)  # -> 32x32
        net = _deconv_block(net, hparams.base_depth, 5, 1)      # 32x32
        net = _deconv_block(net, hparams.base_depth, 5, 2)      # -> 64x64
        net = _deconv_block(net, hparams.base_depth, 5, 1)      # 64x64

        # Project to single-channel logits and wrap as one distribution
        # over the whole image (last 3 dims treated as event dims).
        net = tf.layers.Conv2D(1, 5, padding='SAME', activation=None)(net)
        return tfd.Independent(
            tfd.Bernoulli(logits=net),
            reinterpreted_batch_ndims=3,
            name='image')
def visual_encoder(self, inputs, clss, hparams, train):
    """Encode an image into a flat latent pre-bottleneck vector.

    Maps [batch, 64, 64, 1] -> [batch, 2 * bottleneck_bits] through a
    stack of conv layers, each followed by class-conditional instance
    norm (FiLM-style) and ReLU.

    Args:
      inputs: [batch, 64, 64, 1] image tensor.
      clss: class-id tensor, flattened to [batch] for conditioning.
      hparams: hyperparameters; reads `base_depth`, `bottleneck_bits`,
        and `num_categories`.
      train: accepted for interface parity; not used in this body.

    Returns:
      A [batch, 2 * hparams.bottleneck_bits] tensor.
    """
    with tf.variable_scope('visual_encoder', reuse=tf.AUTO_REUSE):
        labels = tf.reshape(clss, [-1])

        def _conv_block(x, depth, kernel, stride):
            # Conv -> conditional instance norm -> ReLU.
            x = tf.layers.Conv2D(
                depth, kernel, stride, padding='SAME', activation=None)(x)
            x = ops.conditional_instance_norm(
                x, labels, hparams.num_categories)
            return tf.nn.relu(x)

        # Downsample 64x64 -> 4x4 (strided stages halve the resolution).
        net = _conv_block(inputs, hparams.base_depth, 5, 1)           # 64x64
        net = _conv_block(net, hparams.base_depth, 5, 2)              # -> 32x32
        net = _conv_block(net, 2 * hparams.base_depth, 5, 1)          # 32x32
        net = _conv_block(net, 2 * hparams.base_depth, 5, 2)          # -> 16x16
        net = _conv_block(net, 2 * hparams.bottleneck_bits, 4, 2)     # -> 8x8
        net = _conv_block(net, 2 * hparams.bottleneck_bits, 4, 2)     # -> 4x4

        # 4*4*(2*bottleneck_bits) features, projected to the latent size.
        net = tf.layers.flatten(net)
        return tf.layers.dense(
            net, 2 * hparams.bottleneck_bits, activation=None)
def vis_encoder(self, sources_psr, targets_psr, targets_cls):
    """Encode target rendered glyphs into a [batch, 32] bottleneck vector.

    Hard-codes base_depth=32, num_categories=52, bottleneck_bits=32 and
    applies a conv / conditional-instance-norm / ReLU stack to
    `targets_psr` only.

    Args:
      sources_psr: accepted for interface parity; NOT used by this
        encoder (the previous version reshaped it and discarded the
        result — that dead computation has been removed).
      targets_psr: tensor reshapeable to [batch, 64, 64, 1].
      targets_cls: class-id tensor, flattened to [batch] for conditioning.

    Returns:
      A [batch, 32] tensor.
    """
    base_depth = 32
    num_categories = 52
    bottleneck_bits = 32
    targets_psr = tf.reshape(targets_psr, [-1, 64, 64, 1])
    # Build directly in the root variable scope with AUTO_REUSE
    # (auxiliary_name_scope=False keeps op name scopes untouched) so
    # these layers share weights with a prior construction elsewhere.
    with tf.variable_scope(tf.VariableScope(tf.AUTO_REUSE, ''),
                           reuse=tf.AUTO_REUSE,
                           auxiliary_name_scope=False):
        clss = tf.reshape(targets_cls, [-1])
        ret = targets_psr

        def _conv_block(x, depth, kernel, stride):
            # Conv -> conditional instance norm (FiLM) -> ReLU.
            x = tf.layers.Conv2D(
                depth, kernel, stride, padding='SAME', activation=None)(x)
            x = ops.conditional_instance_norm(x, clss, num_categories)
            return tf.nn.relu(x)

        # Downsample 64x64 -> 2x2 (six stride-2 stages).
        ret = _conv_block(ret, base_depth, 5, 1)            # 64x64
        ret = _conv_block(ret, base_depth, 5, 2)            # -> 32x32
        ret = _conv_block(ret, 2 * base_depth, 5, 1)        # 32x32
        ret = _conv_block(ret, 2 * base_depth, 5, 2)        # -> 16x16
        ret = _conv_block(ret, 2 * bottleneck_bits, 4, 2)   # -> 8x8
        ret = _conv_block(ret, 2 * bottleneck_bits, 4, 2)   # -> 4x4
        ret = _conv_block(ret, 2 * bottleneck_bits, 4, 2)   # -> 2x2

        # 2*2*(2*bottleneck_bits) = 256 features (the old "ret has 1024"
        # comment was stale for this 7-conv variant).
        ret = tf.layers.flatten(ret)
        ret = tf.layers.dense(ret, bottleneck_bits, activation=None)
        return ret