Example #1
0
    def _construct_critic(self):
        """
        Build the convolutional critic over images.

        Returns a compiled Keras model mapping an image of shape
        `self.img_dim` to a single linear score, trained with the
        Wasserstein loss.
        """
        img_in = Input(shape=self.img_dim)
        features = convnet(img_in, self.d_hidden, bias=False)
        flattened = Flatten()(features)
        hidden = bn_dense(flattened, 1024)
        score = bn_dense(hidden, 1, activation='linear', use_bias=False)

        critic_model = Model(img_in, score)
        critic_model.compile(optimizer=self.critic_opt(lr=self.critic_lr),
                             loss=wasserstein_loss)
        return critic_model
Example #2
0
    def _construct_generator(self):
        """
        Build the fully-connected latent generator.

        Returns an (uncompiled) Keras model mapping a noise vector of
        size `self.noise_dim` to a fake code of size `self.latent_dim`.
        """
        noise_in = Input(shape=(self.noise_dim,))
        hidden = bn_dense(noise_in, 64)
        hidden = bn_dense(hidden, 128)
        hidden = bn_dense(hidden, 256)
        fake_code = bn_dense(hidden, self.latent_dim)

        return Model(noise_in, fake_code)
Example #3
0
    def _construct_critic(self):
        """
        Build the fully-connected critic over latent codes.

        Returns a compiled Keras model mapping a latent vector to a
        single linear score, trained with the (weighted) Wasserstein
        loss.
        """
        code_in = Input(shape=(self.latent_dim, ))
        h = bn_dense(code_in, 128, activation='relu')
        h = bn_dense(h, 64, activation='relu')
        h = bn_dense(h, 64, activation='relu')
        score = bn_dense(h, 1, activation='linear')

        critic_model = Model(code_in, score)
        critic_model.compile(
            optimizer=self.critic_opt(lr=self.critic_learning_rate),
            loss=[wasserstein_loss],
            loss_weights=[self.critic_weight])
        return critic_model
Example #4
0
    def _construct_encoder(self):
        """
        Build the CNN encoder for a VAE.

        Maps an image to a latent sample z via the reparameterization
        trick, and builds the VAE loss closure (per-image reconstruction
        BCE + beta-weighted KL divergence) over the captured
        z_mu / z_log_sigma tensors.

        Returns:
            (encoder, vae_loss_fn): the Keras encoder model and the loss
            function to use when compiling the full autoencoder.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param, bias=False)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z_mu = Dense(self.latent_dim)(fc_1)
        z_log_sigma = Dense(self.latent_dim)(fc_1)

        def sample_z(args):
            # Reparameterization trick: z = mu + sigma * eps, with eps
            # drawn from N(self.mu, self.std_dev).
            mu, log_sigma = args
            epsilon = K.random_normal(shape=K.shape(mu),
                                      mean=self.mu,
                                      stddev=self.std_dev)
            return mu + K.exp(log_sigma / 2) * epsilon

        def vae_loss_fn(x, x_decoded_mean):
            x = K.flatten(x)
            x_decoded_mean = K.flatten(x_decoded_mean)
            # np.prod replaces np.product, which was deprecated and
            # removed in NumPy 2.0. Scales the mean per-pixel BCE up to
            # a per-image reconstruction loss.
            flat_dim = np.prod(self.img_dim)
            reconst_loss = flat_dim * binary_crossentropy(x, x_decoded_mean)
            kl_loss = -0.5 * K.sum(
                1 + z_log_sigma - K.square(z_mu) - K.exp(z_log_sigma), axis=-1)
            return reconst_loss + (self.beta * kl_loss)

        z = Lambda(sample_z)([z_mu, z_log_sigma])

        encoder = Model(img, z)
        return encoder, vae_loss_fn
Example #5
0
File: aae.py  Project: jinac/autoencoders
    def _construct_critic(self):
        """
        Build the fully-connected discriminator over latent codes.

        Three hidden layers with LeakyReLU(0.2) activations and a
        sigmoid real/fake output; compiled with binary cross-entropy.
        """
        code_in = Input(shape=(self.latent_dim, ))
        h = LeakyReLU(0.2)(bn_dense(code_in, 64, activation=None))
        h = LeakyReLU(0.2)(bn_dense(h, 32, activation=None))
        h = LeakyReLU(0.2)(bn_dense(h, 32, activation=None))
        real_prob = bn_dense(h, 1, activation='sigmoid')

        critic_model = Model(code_in, real_prob)
        critic_model.compile(
            optimizer=self.critic_opt(lr=self.critic_learning_rate),
            loss='binary_crossentropy')
        return critic_model
Example #6
0
File: gan.py  Project: jinac/autoencoders
    def _construct_critic(self):
        """
        Build the convolutional critic/discriminator over images.

        Five strided SELU conv layers (4x4 kernels, stride 2, no
        batchnorm) with filter widths d, d, 2d, 4d, 8d, followed by a
        dense head and a sigmoid real/fake output; compiled with binary
        cross-entropy.
        """
        img = Input(shape=self.img_dim)
        # Collapse the five identical copy-pasted layer calls into a
        # loop; dead commented-out d6-d8 layers removed.
        d = img
        for width_mult in (1, 1, 2, 4, 8):
            d = bn_conv_layer(d,
                              width_mult * self.d_hidden,
                              4,
                              2,
                              activation='selu',
                              batchnorm=False)
        d_flat = Flatten()(d)
        d_dense = bn_dense(d_flat, 1024)
        disc_out = bn_dense(d_dense, 1, activation='sigmoid')

        critic = Model(img, disc_out)
        critic.compile(optimizer=self.critic_opt(lr=self.critic_lr),
                       loss='binary_crossentropy')
        return critic
Example #7
0
    def _construct_encoder(self):
        """
        Build the CNN encoder for a WAE, along with its MMD-based loss.

        Returns:
            (encoder, vae_loss_fn): a Keras model mapping an image to a
            deterministic latent code z, and a loss closure combining
            reconstruction BCE with an MMD penalty that pushes the
            encoded distribution toward a Gaussian prior.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z = Dense(self.latent_dim)(fc_1)

        def mmd_loss_fn(sample_qz, sample_pz):
            """
            Maximum mean discrepancy between the encoded batch
            (sample_qz) and prior samples (sample_pz), using a sum of
            inverse multiquadratic kernels over several scales.

            MMD implementation adapted from
            https://github.com/tolstikhin/wae/blob/master/wae.py
            """
            sigma2_p = 1.**2
            # Base kernel bandwidth, scaled with latent dimensionality.
            C_base = 2. * self.latent_dim * self.std_dev
            n = K.shape(sample_pz)[0]
            n = K.cast(n, 'int32')
            nf = K.cast(n, 'float32')

            # Pairwise squared distances within the prior batch:
            # ||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>.
            norms_pz = K.sum(K.square(sample_pz), axis=1, keepdims=True)
            dotprods_pz = K.dot(sample_pz, K.transpose(sample_pz))
            distances_pz = norms_pz + K.transpose(norms_pz) - 2. * dotprods_pz

            # Pairwise squared distances within the encoded batch.
            norms_qz = K.sum(K.square(sample_qz), axis=1, keepdims=True)
            dotprods_qz = K.dot(sample_qz, K.transpose(sample_qz))
            distances_qz = norms_qz + K.transpose(norms_qz) - 2. * dotprods_qz

            # Cross distances between encoded and prior samples.
            dotprods = K.dot(sample_qz, K.transpose(sample_pz))
            distances = norms_qz + K.transpose(norms_pz) - 2. * dotprods

            stat = 0.
            for scale in [.1, .2, .5, 1., 2., 5., 10.]:
                C = C_base * scale
                # Within-distribution kernel terms; the diagonal
                # (self-distances) is masked out before averaging over
                # the n*(n-1) off-diagonal pairs.
                res1 = C / (C + distances_qz)
                res1 += C / (C + distances_pz)
                res1 = tf.multiply(res1, 1. - tf.eye(n))
                res1 = K.sum(res1) / (nf * nf - nf)
                # Cross-distribution kernel term, averaged over all
                # n*n pairs.
                res2 = C / (C + distances)
                res2 = K.sum(res2) * 2. / (nf * nf)
                stat += res1 - res2

            return stat

        def vae_loss_fn(x, x_decoded):
            # Reconstruction BCE plus MMD between the encoded batch and
            # samples drawn from a standard normal of matching shape.
            reconst_loss = binary_crossentropy(K.flatten(x),
                                               K.flatten(x_decoded))
            mmd_loss = mmd_loss_fn(z, K.random_normal(K.shape(z)))

            return reconst_loss + (self.mmd_weight * mmd_loss)

        encoder = Model(img, z)
        return encoder, vae_loss_fn
Example #8
0
    def _construct_encoder(self):
        """
        Build the CNN encoder mapping an image to a latent code.
        """
        img_in = Input(shape=self.img_dim)
        features = Flatten()(convnet(img_in, self.enc_param))
        hidden = bn_dense(features, self.hidden_dim)
        latent = Dense(self.latent_dim)(hidden)

        return Model(img_in, latent)
Example #9
0
File: aae.py  Project: jinac/autoencoders
    def _construct_encoder(self):
        """
        Build the CNN encoder: four strided conv layers, then a dense
        bottleneck down to the latent code.
        """
        img_in = Input(shape=self.img_dim)
        h = bn_conv_layer(img_in, self.img_dim[-1], 4, 2)
        h = bn_conv_layer(h, 64, 4, 2)
        h = bn_conv_layer(h, 128, 4, 2)
        h = bn_conv_layer(h, 256, 4, 2)
        h = Flatten()(h)
        h = bn_dense(h, self.hidden_dim)
        latent = Dense(self.latent_dim)(h)

        return Model(img_in, latent)
Example #10
0
    def _construct_encoder(self):
        """
        Build the CNN encoder and an MMD-regularized reconstruction loss.

        Returns:
            (encoder, vae_loss_fn): a Keras model mapping an image to a
            deterministic latent code z, plus a loss closure over it.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z = Dense(self.latent_dim)(fc_1)

        def vae_loss_fn(x, x_decoded):
            # Reconstruction BCE plus MMD penalty between the encoded
            # batch and samples from a standard normal prior.
            # NOTE(review): mmd_loss_fn is not defined in this method —
            # presumably a module-level helper elsewhere in the file;
            # confirm it is in scope, otherwise this raises NameError
            # when the loss is evaluated.
            reconst_loss = binary_crossentropy(K.flatten(x),
                                               K.flatten(x_decoded))
            mmd_loss = mmd_loss_fn(z, K.random_normal(K.shape(z)))

            return reconst_loss + (self.mmd_weight * mmd_loss)

        encoder = Model(img, z)
        return encoder, vae_loss_fn
Example #11
0
    def _construct_critic(self):
        """
        Build the convolutional critic producing a single linear score.

        The model is returned uncompiled; compiling with the
        Wasserstein loss is left to the caller.
        """
        img_in = Input(shape=self.img_dim)
        conv_out = convnet(img_in,
                           self.d_hidden,
                           batchnorm=False,
                           activation='selu',
                           bias=False)
        flat = Flatten()(conv_out)
        fc = bn_dense(flat,
                      1024,
                      batchnorm=False,
                      activation='selu',
                      use_bias=False)
        score = Dense(1, use_bias=False)(fc)

        return Model(img_in, score)
Example #12
0
    def _construct_generator(self):
        """
        Build the deconvolutional generator from a latent vector to an
        image with tanh output.

        The model is returned uncompiled.
        """
        z_in = Input(shape=(self.in_dim, ))
        projected = bn_dense(z_in, 512)
        spatial = Reshape((1, 1, 512))(projected)
        upsampled = deconvnet(spatial,
                              self.img_dim,
                              self.g_hidden,
                              activation='selu',
                              batchnorm=True,
                              bias=False)
        gen_img = bn_deconv_layer(upsampled,
                                  self.img_dim[-1],
                                  4,
                                  2,
                                  activation='tanh',
                                  batchnorm=False,
                                  use_bias=False)

        return Model(z_in, gen_img)