示例#1
0
def resModule(gamma2, x, filters, dropout, s=(3, 3)): # {{{
	"""Two-convolution residual block blended with its input.

	Runs `x` through `convModule` twice, using the same `filters`,
	`dropout` and kernel size `s` each time. The shortcut `x` and the
	branch output are then combined by `merge` with the weighted sum
	gamma1*shortcut + gamma2*branch.

	`gamma1` is not a parameter. It is looked up outside this function
	when the combine function runs. NOTE(review): presumably a
	module-level global — confirm. If the outer name `observing` is
	truthy, the merged output is passed through `Observation()`.
	"""
	def _blend(shortcut, branch):
		return gamma1*shortcut + gamma2*branch

	branch = x
	for _ in range(2):
		branch = convModule(branch, filters, dropout, s=s)
	merged = merge(x, branch, f=_blend)
	return Observation()(merged) if observing else merged # }}}
示例#2
0
 def hidden_transform(self, net, d_vp, reuse=None, training=True):
     """Condition a feature tensor on `d_vp` through a fully connected layer.

     Steps:
       1. Flatten `net`, keeping dim 0 as the batch dimension.
       2. Combine it with `d_vp` using the project `merge` helper along
          dim 1.
       3. Map the result through a fully connected layer
          ('fc_transform') that has as many outputs as `net` has
          features per example.
       4. Reshape back to `net`'s original non-batch shape.

     Args:
         net: Input feature tensor. All dimensions except dim 0 must be
             statically known, because their product sizes the layer.
         d_vp: Tensor merged with the flattened features. Callers in
             this file pass a viewpoint difference such as `vp2 - vp1`.
             NOTE(review): its exact shape is not visible here.
         reuse: Variable-reuse flag for the 'transformer' variable scope.
         training: Forwarded to `vpnet_argscope`.

     Returns:
         A tensor with the same non-batch shape as `net`. The batch
         dimension is inferred, so an unknown (None) batch size works.
     """
     with tf.variable_scope('transformer', reuse=reuse):
         with slim.arg_scope(vpnet_argscope(training=training,
                                            center=False)):
             in_shape = net.get_shape().as_list()
             # Features per example. slim.flatten keeps dim 0 and
             # collapses the rest, so exclude dim 0 here too.
             num_features = int(np.prod(in_shape[1:]))
             net = slim.flatten(net)
             net = merge(net, d_vp, dim=1)
             net = slim.fully_connected(net,
                                        num_features,
                                        scope='fc_transform')
             # as_list() reports an unknown batch size as None, which
             # tf.reshape rejects; -1 lets TF infer it at run time.
             net = tf.reshape(net, [-1] + in_shape[1:])
             return net
示例#3
0
    def net(self, im1, im2, vp1, vp2, reuse=None, training=True):
        """Builds the full VPNet graph for a pair of images and viewpoints.

        Each image is encoded. Its encoding is transformed by the
        viewpoint offset toward the other image's viewpoint and then
        decoded. The decoded images are passed through the same
        encode/transform/decode pipeline with the opposite offset,
        producing cycle reconstructions. The discriminator then scores
        and classifies the results.

        Args:
            im1: First input image batch.
            im2: Second input image batch.
            vp1: Viewpoint tensor associated with `im1`.
            vp2: Viewpoint tensor associated with `im2`. Only the
                differences `vp2 - vp1` and `vp1 - vp2` are used.
            reuse: Passed to the first call of each sub-network
                (encoder, hidden_transform, decoder, discriminate,
                classify). Every later call uses reuse=True so the
                variables are shared.
            training: Forwarded to every sub-network call.

        Returns:
            A tuple of eight tensors, in this order:
            dec_im1: Decoding of im1's encoding transformed by vp2 - vp1.
            dec_im2: Decoding of im2's encoding transformed by vp1 - vp2.
            dec_ed1: dec_im1 re-encoded, transformed by vp1 - vp2 and
                decoded again (cycle reconstruction of im1).
            dec_ed2: dec_im2 re-encoded, transformed by vp2 - vp1 and
                decoded again (cycle reconstruction of im2).
            class_out_real: discriminator.classify output on
                merge(dec_ed1, dec_ed2, dim=0).
            class_out_fake: discriminator.classify output on
                merge(dec_im2, dec_im1, dim=0).
            disc_out_real: discriminator.discriminate output on
                merge(dec_ed1, dec_ed2, dim=0).
            disc_out_fake: discriminator.discriminate output on
                merge(dec_im1, dec_im2, dim=0).

        NOTE(review): the "real" discriminator/classifier inputs are the
        cycle reconstructions (dec_ed*), not the raw inputs im1/im2.
        Confirm this is intended.
        """
        # First pass: encode, shift the encoding toward the other
        # viewpoint, then decode.
        enc_im1 = self.encoder(im1, reuse=reuse, training=training)
        enc_im2 = self.encoder(im2, reuse=True, training=training)

        enc_im1 = self.hidden_transform(enc_im1,
                                        vp2 - vp1,
                                        reuse=reuse,
                                        training=training)
        enc_im2 = self.hidden_transform(enc_im2,
                                        vp1 - vp2,
                                        reuse=True,
                                        training=training)

        dec_im1 = self.decoder(enc_im1, reuse=reuse, training=training)
        dec_im2 = self.decoder(enc_im2, reuse=True, training=training)

        # Second pass: push the decoded images back through the pipeline
        # with the opposite viewpoint offset.
        enc_dec1 = self.encoder(
            dec_im1, reuse=True,
            training=training)  #TODO: Maybe set training to False in one usage
        enc_dec2 = self.encoder(dec_im2, reuse=True, training=training)

        enc_dec1 = self.hidden_transform(
            enc_dec1, vp1 - vp2, reuse=True,
            training=training)  #TODO: Maybe set training to False in one usage
        enc_dec2 = self.hidden_transform(enc_dec2,
                                         vp2 - vp1,
                                         reuse=True,
                                         training=training)

        dec_ed1 = self.decoder(
            enc_dec1, reuse=True,
            training=training)  #TODO: Maybe set training to False in one usage
        dec_ed2 = self.decoder(enc_dec2, reuse=True, training=training)

        # Build input for discriminator
        disc_in_fake = merge(dec_im1, dec_im2, dim=0)
        disc_in_real = merge(dec_ed1, dec_ed2, dim=0)

        disc_out_fake, _ = self.discriminator.discriminate(disc_in_fake,
                                                           reuse=reuse,
                                                           training=training)
        disc_out_real, _ = self.discriminator.discriminate(disc_in_real,
                                                           reuse=True,
                                                           training=training)

        # The classifier's "real" input is the same concatenation as the
        # discriminator's, so reuse it rather than merging twice. The
        # "fake" input is ordered (dec_im2, dec_im1) on purpose, unlike
        # disc_in_fake.
        class_in_real = disc_in_real
        class_in_fake = merge(dec_im2, dec_im1, dim=0)

        class_out_real = self.discriminator.classify(class_in_real,
                                                     3,
                                                     reuse=reuse,
                                                     training=training)
        class_out_fake = self.discriminator.classify(class_in_fake,
                                                     3,
                                                     reuse=True,
                                                     training=training)

        return dec_im1, dec_im2, dec_ed1, dec_ed2, class_out_real, class_out_fake, disc_out_real, disc_out_fake
示例#4
0
 def vp_label(self, vp1, vp2):
     """Build the label tensor from a pair of viewpoint tensors.

     Joins `vp1` followed by `vp2` with the project `merge` helper along
     dim 0. `self` is not used.
     """
     stacked = merge(vp1, vp2, dim=0)
     return stacked