Example No. 1
    def _construct_encoder(self):
        """
        CNN encoder.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param, bias=False)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z_mu = Dense(self.latent_dim)(fc_1)
        z_log_sigma = Dense(self.latent_dim)(fc_1)

        def sample_z(args):
            # Reparameterisation trick: z = mu + sigma * epsilon, with epsilon
            # drawn from a fixed Gaussian so gradients flow through mu and
            # log_sigma.
            mu, log_sigma = args
            epsilon = K.random_normal(shape=K.shape(mu),
                                      mean=self.mu,
                                      stddev=self.std_dev)
            return mu + K.exp(log_sigma / 2) * epsilon

        def vae_loss_fn(x, x_decoded_mean):
            # Reconstruction term: per-pixel binary cross-entropy scaled by
            # the number of pixels, so it is effectively summed per image.
            x = K.flatten(x)
            x_decoded_mean = K.flatten(x_decoded_mean)
            flat_dim = np.prod(self.img_dim)
            reconst_loss = flat_dim * binary_crossentropy(x, x_decoded_mean)
            # KL divergence between the approximate posterior and N(0, I),
            # weighted by beta.
            kl_loss = -0.5 * K.sum(
                1 + z_log_sigma - K.square(z_mu) - K.exp(z_log_sigma), axis=-1)
            return reconst_loss + (self.beta * kl_loss)

        z = Lambda(sample_z)([z_mu, z_log_sigma])

        encoder = Model(img, z)
        return encoder, vae_loss_fn
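
For orientation, here is a minimal sketch of how the returned pair might be wired into a full autoencoder. The `_construct_decoder` helper, the `_construct_vae` method name, and the optimizer choice are assumptions, not part of the snippet above.

    def _construct_vae(self):
        # Hypothetical wiring (assumed, not from the snippet above).
        encoder, vae_loss_fn = self._construct_encoder()
        decoder = self._construct_decoder()   # assumed helper, not shown
        # Reuse the encoder's own input so the loss closure over z_mu and
        # z_log_sigma stays connected to the compiled graph.
        reconstruction = decoder(encoder.output)
        vae = Model(encoder.input, reconstruction)
        vae.compile(optimizer='adam', loss=vae_loss_fn)
        return vae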
Example No. 2
    def _construct_encoder(self):
        """
        CNN encoder with mmd loss.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z = Dense(self.latent_dim)(fc_1)

        def mmd_loss_fn(sample_qz, sample_pz):
            """
            MMD estimate with an inverse multiquadratics kernel, adapted from
            https://github.com/tolstikhin/wae/blob/master/wae.py
            """
            sigma2_p = 1.**2   # prior variance (unused in this variant)
            C_base = 2. * self.latent_dim * self.std_dev
            n = K.shape(sample_pz)[0]
            n = K.cast(n, 'int32')
            nf = K.cast(n, 'float32')

            # Pairwise squared distances within the prior sample, within the
            # encoded sample, and across the two, via
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>.
            norms_pz = K.sum(K.square(sample_pz), axis=1, keepdims=True)
            dotprods_pz = K.dot(sample_pz, K.transpose(sample_pz))
            distances_pz = norms_pz + K.transpose(norms_pz) - 2. * dotprods_pz

            norms_qz = K.sum(K.square(sample_qz), axis=1, keepdims=True)
            dotprods_qz = K.dot(sample_qz, K.transpose(sample_qz))
            distances_qz = norms_qz + K.transpose(norms_qz) - 2. * dotprods_qz

            dotprods = K.dot(sample_qz, K.transpose(sample_pz))
            distances = norms_qz + K.transpose(norms_pz) - 2. * dotprods

            # Multi-scale inverse multiquadratics kernel
            # k(a, b) = C / (C + ||a - b||^2): res1 averages the within-sample
            # kernel values (diagonal excluded, giving an unbiased estimate),
            # res2 the cross-sample values.
            stat = 0.
            for scale in [.1, .2, .5, 1., 2., 5., 10.]:
                C = C_base * scale
                res1 = C / (C + distances_qz)
                res1 += C / (C + distances_pz)
                res1 = tf.multiply(res1, 1. - tf.eye(n))
                res1 = K.sum(res1) / (nf * nf - nf)
                res2 = C / (C + distances)
                res2 = K.sum(res2) * 2. / (nf * nf)
                stat += res1 - res2

            return stat

        def vae_loss_fn(x, x_decoded):
            # Reconstruction term plus a weighted MMD between the encoded
            # batch and a batch drawn from the N(0, I) prior.
            reconst_loss = binary_crossentropy(K.flatten(x),
                                               K.flatten(x_decoded))
            mmd_loss = mmd_loss_fn(z, K.random_normal(K.shape(z)))

            return reconst_loss + (self.mmd_weight * mmd_loss)

        encoder = Model(img, z)
        return encoder, vae_loss_fn
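
The distance computation above leans on the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>. A tiny standalone NumPy check of that identity (illustration only, independent of the snippet):

import numpy as np

# Verify the squared-distance expansion used in mmd_loss_fn.
rng = np.random.default_rng(0)
a = rng.normal(size=(4, 3))
norms = np.sum(a ** 2, axis=1, keepdims=True)
dists_expanded = norms + norms.T - 2. * (a @ a.T)
dists_direct = np.sum((a[:, None, :] - a[None, :, :]) ** 2, axis=-1)
assert np.allclose(dists_expanded, dists_direct)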
Example No. 3
    def _construct_critic(self):
        """
        WGAN critic: convolutional features followed by a single linear
        score (no sigmoid), trained with the Wasserstein loss.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.d_hidden, bias=False)
        d_flat = Flatten()(conv_block)
        d_dense = bn_dense(d_flat, 1024)
        disc_out = bn_dense(d_dense, 1, activation='linear', use_bias=False)

        critic = Model(img, disc_out)
        critic.compile(optimizer=self.critic_opt(lr=self.critic_lr),
                       loss=wasserstein_loss)
        return critic
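
`wasserstein_loss` is referenced here but not shown in these snippets. The usual Keras formulation looks like the following (an assumption about what the helper does, not the author's code):

from keras import backend as K

def wasserstein_loss(y_true, y_pred):
    # With labels of +1 / -1 for real and generated batches, minimising the
    # mean product pushes the critic's scores for the two distributions
    # apart, approximating the Wasserstein-1 distance.
    return K.mean(y_true * y_pred)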
Example No. 4
    def _construct_encoder(self):
        """
        CNN encoder.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z = Dense(self.latent_dim)(fc_1)

        encoder = Model(img, z)
        return encoder
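
The `convnet` and `bn_dense` helpers appear in every example but are never shown. Below is a plausible minimal sketch that matches the call signatures above; the layer choices, kernel size, and strides are assumptions, not the original implementation.

from keras.layers import Activation, BatchNormalization, Conv2D, Dense

def bn_dense(x, units, activation='relu', batchnorm=True, use_bias=True):
    # Dense layer, optional batch normalisation, then the activation.
    x = Dense(units, use_bias=use_bias)(x)
    if batchnorm:
        x = BatchNormalization()(x)
    return Activation(activation)(x)

def convnet(x, filters, activation='relu', batchnorm=True, bias=True):
    # Stack of strided convolutions, one per entry in `filters`.
    for n_filters in filters:
        x = Conv2D(n_filters, 3, strides=2, padding='same', use_bias=bias)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation(activation)(x)
    return x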
Example No. 5
    def _construct_encoder(self):
        """
        CNN encoder.
        """
        img = Input(shape=self.img_dim)
        conv_block = convnet(img, self.enc_param)
        flat_1 = Flatten()(conv_block)
        fc_1 = bn_dense(flat_1, self.hidden_dim)
        z = Dense(self.latent_dim)(fc_1)

        def vae_loss_fn(x, x_decoded):
            # `mmd_loss_fn` is not defined in this snippet; it is expected to
            # be available from the enclosing scope (compare Example No. 2).
            reconst_loss = binary_crossentropy(K.flatten(x),
                                               K.flatten(x_decoded))
            mmd_loss = mmd_loss_fn(z, K.random_normal(K.shape(z)))

            return reconst_loss + (self.mmd_weight * mmd_loss)

        encoder = Model(img, z)
        return encoder, vae_loss_fn
Example No. 6
    def _construct_critic(self):
        """
        WGAN critic variant with SELU activations and no batch
        normalisation; the output is a single unbounded score.
        """
        img = Input(shape=self.img_dim)
        d = convnet(img,
                    self.d_hidden,
                    batchnorm=False,
                    activation='selu',
                    bias=False)
        d_flat = Flatten()(d)
        d_fc = bn_dense(d_flat,
                        1024,
                        batchnorm=False,
                        activation='selu',
                        use_bias=False)
        disc_out = Dense(1, use_bias=False)(d_fc)

        critic = Model(img, disc_out)
        # critic.compile(optimizer=self.critic_opt(lr=self.critic_lr),
        #                loss=wasserstein_loss)
        return critic
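
With the standalone compile commented out, this critic is presumably compiled or trained elsewhere. One common Keras pattern, sketched here as an assumption (the function name, optimizer, and learning rate are not from the snippet), freezes the critic inside a combined generator-critic model:

from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop

def build_combined(generator, critic, latent_dim):
    # Freeze the critic so only the generator is updated through this model.
    critic.trainable = False
    z = Input(shape=(latent_dim,))
    score = critic(generator(z))
    combined = Model(z, score)
    combined.compile(optimizer=RMSprop(lr=5e-5), loss=wasserstein_loss)
    return combined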