Example #1
0
    def kldiv(self):
        """Compute the (negative) KL contribution of this dense layer.

        Combines the closed-form Gaussian KL for the weights and biases
        with the log-density corrections for the auxiliary multiplicative
        variable z: log q(z) (posterior, with flow log-determinant) and
        log r(z|W) (auxiliary inverse model).  Resembles the MNF scheme
        of Louizos & Welling (2017) -- TODO confirm.

        Returns:
            Scalar tensor: -KL(weights) + log r(z) - log q(z) - KL(bias).
        """
        # Sample the multiplicative variable z and its flow log-determinants.
        M, logdets = self.sample_z()
        logdets = logdets[0]
        M = tf.squeeze(M)

        std_mg = self.get_params_W()   # posterior std-dev of the weights
        qm0 = self.get_params_m()      # variance of q(z) before the flow
        # Guard the scalar case: squeezing a (1, 1) sample yields rank 0,
        # which must not be expand_dims'ed like the vector case.
        if len(M.get_shape()) == 0:
            Mexp = M
        else:
            Mexp = tf.expand_dims(M, 1)

        # Effective posterior mean/variance of W given z
        # (mean scaled by z -- presumably per input row; TODO confirm axis).
        Mtilde = Mexp * self.mu_W
        Vtilde = tf.square(std_mg)

        # Prior variance (log-parameterized in self.pvar) broadcast to the
        # full weight-matrix shape via an outer product with ones.
        iUp = outer(tf.exp(self.pvar), ones_d((self.output_dim, )))

        # log q(z): Gaussian entropy term minus the flow log-determinant.
        logqm = 0.
        if self.use_z:
            logqm = -tf.reduce_sum(.5 * (tf.log(2 * np.pi) + tf.log(qm0) + 1))
            logqm -= logdets

        # Closed-form KL between diagonal Gaussians, summed over all weights.
        # Note -tf.log(std_mg) equals -.5 * tf.log(Vtilde).
        kldiv_w = tf.reduce_sum(.5 * tf.log(iUp) - tf.log(std_mg) +
                                ((Vtilde + tf.square(Mtilde)) /
                                 (2 * iUp)) - .5)
        # Same closed-form Gaussian KL for the biases
        # (bias parameters stored as log-variances).
        kldiv_bias = tf.reduce_sum(
            .5 * self.pvar_bias - .5 * self.logvar_bias +
            ((tf.exp(self.logvar_bias) + tf.square(self.mu_bias)) /
             (2 * tf.exp(self.pvar_bias))) - .5)

        if self.use_z:
            apvar_M = self.apvar_M
            # shared network for hidden layer
            mw = tf.matmul(tf.expand_dims(apvar_M, 0), Mtilde)
            eps = tf.expand_dims(tf.random_normal((self.output_dim, )), 0)
            varw = tf.matmul(tf.square(tf.expand_dims(apvar_M, 0)), Vtilde)
            # Reparameterized stochastic activation of the auxiliary network.
            a = tf.nn.tanh(mw + tf.sqrt(varw) * eps)
            # split at output layer
            if len(tf.squeeze(a).get_shape()) != 0:
                w__ = tf.reduce_mean(outer(self.rsr_M, tf.squeeze(a)), axis=1)
                wv__ = tf.reduce_mean(outer(self.rsri_M, tf.squeeze(a)),
                                      axis=1)
            else:
                # Scalar activation: plain products instead of outer().
                w__ = self.rsr_M * tf.squeeze(a)
                wv__ = self.rsri_M * tf.squeeze(a)

            logrm = 0.
            # Optionally push z through the auxiliary flow r before
            # evaluating its log-density.
            if self.flow_r is not None:
                M, logrm = self.flow_r.get_output_for(tf.expand_dims(M, 0))
                M = tf.squeeze(M)
                logrm = logrm[0]

            # log r(z|W): Gaussian log-density of z with mean w__ and
            # log-precision wv__ (exp(wv__) multiplies the squared error).
            logrm += tf.reduce_sum(-.5 * tf.exp(wv__) * tf.square(M - w__) -
                                   .5 * tf.log(2 * np.pi) + .5 * wv__)
        else:
            logrm = 0.

        return -kldiv_w + logrm - logqm - kldiv_bias
Example #2
0
    def kldiv(self):
        """Compute the (negative) KL contribution of this conv layer.

        Convolutional variant: kernels are flattened to (-1, nb_filter)
        so the multiplicative variable z scales each filter.  Combines
        the closed-form Gaussian KL for weights and biases with the
        log-density corrections log q(z) and log r(z|W).  Appears to
        follow the MNF construction of Louizos & Welling (2017) --
        TODO confirm.

        Returns:
            Scalar tensor: -KL(weights) + log r(z) - log q(z) - KL(bias).
        """
        # Sample the multiplicative variable z and its flow log-determinants.
        M, logdets = self.sample_z()
        logdets = logdets[0]
        M = tf.squeeze(M)

        std_w = self.get_params_W()
        # Flatten kernels to a (fan_in, nb_filter) matrix so the per-filter
        # multiplier z broadcasts along the second axis.
        mu = tf.reshape(self.mu_W, [-1, self.nb_filter])
        std_w = tf.reshape(std_w, [-1, self.nb_filter])
        Mtilde = mu * tf.expand_dims(M, 0)   # posterior mean scaled by z
        mbias = self.mu_bias * M             # bias mean scaled by z
        Vtilde = tf.square(std_w)            # posterior weight variance

        # Prior variance (log-parameterized in self.pvar) broadcast to the
        # flattened weight-matrix shape.
        iUp = outer(tf.exp(self.pvar), ones_d((self.nb_filter, )))

        qm0 = self.get_params_m()
        logqm = 0.
        if self.use_z > 0.:
            # log q(z): Gaussian entropy term minus the flow log-determinant.
            logqm = -tf.reduce_sum(.5 * (tf.log(2 * np.pi) + tf.log(qm0) + 1))
            logqm -= logdets

        # Closed-form KL between diagonal Gaussians, summed over all weights.
        kldiv_w = tf.reduce_sum(.5 * tf.log(iUp) - .5 * tf.log(Vtilde) +
                                ((Vtilde + tf.square(Mtilde)) /
                                 (2 * iUp)) - .5)
        # Same closed-form KL for the biases
        # (bias parameters stored as log-variances).
        kldiv_bias = tf.reduce_sum(
            .5 * self.pvar_bias - .5 * self.logvar_bias +
            ((tf.exp(self.logvar_bias) + tf.square(mbias)) /
             (2 * tf.exp(self.pvar_bias))) - .5)

        logrm = 0.
        if self.use_z:
            apvar_M = self.apvar_M
            # Auxiliary network: project the weight posterior through
            # apvar_M with a reparameterized stochastic draw.
            mw = tf.matmul(Mtilde, tf.expand_dims(apvar_M, 1))
            vw = tf.matmul(Vtilde, tf.expand_dims(tf.square(apvar_M), 1))
            eps = tf.expand_dims(tf.random_normal((self.input_dim, )), 1)
            a = mw + tf.sqrt(vw) * eps
            # Include the bias contribution in the auxiliary activation.
            mb = tf.reduce_sum(mbias * apvar_M)
            vb = tf.reduce_sum(tf.exp(self.logvar_bias) * tf.square(apvar_M))
            a += mb + tf.sqrt(vb) * tf.random_normal(())

            # Mean and log-precision of r(z|W) via linear read-outs of a.
            w__ = tf.reduce_mean(outer(tf.squeeze(a), self.rsr_M), axis=0)
            wv__ = tf.reduce_mean(outer(tf.squeeze(a), self.rsri_M), axis=0)

            # Optionally push z through the auxiliary flow r first.
            if self.flow_r is not None:
                M, logrm = self.flow_r.get_output_for(tf.expand_dims(M, 0))
                M = tf.squeeze(M)
                logrm = logrm[0]

            # log r(z|W): Gaussian log-density of z with mean w__ and
            # log-precision wv__ (exp(wv__) multiplies the squared error).
            logrm += tf.reduce_sum(-.5 * tf.exp(wv__) * tf.square(M - w__) -
                                   .5 * tf.log(2 * np.pi) + .5 * wv__)

        return -kldiv_w + logrm - logqm - kldiv_bias
Example #3
0
    def sample_z(self, size_M=1, sample=True):
        """Draw `size_M` samples of the multiplicative variable z.

        Returns a pair `(z, logdets)` where `z` has shape
        `(size_M, nb_filter)` and `logdets` holds the per-sample
        log-determinants accumulated by the posterior flow (zeros when
        no flow is configured).  When `self.use_z` is false, returns a
        tensor of ones (identity multiplier) and zero log-determinants.
        With `sample=False` the deterministic mean is returned instead
        of a stochastic draw.
        """
        if not self.use_z:
            # z disabled: identity multiplier, no flow correction.
            return ones_d((size_M, self.nb_filter)), zeros_d((size_M, ))

        qz_var = self.get_params_m()
        mean_rows = tf.tile(tf.expand_dims(self.qzero_mean, 0), [size_M, 1])
        noise = tf.random_normal(tf.stack((size_M, self.nb_filter)))
        if sample:
            # Reparameterized draw from the diagonal Gaussian q0(z).
            z = mean_rows + tf.sqrt(qz_var) * noise
        else:
            z = mean_rows

        log_det = zeros_d((size_M, ))
        if self.n_flows_q > 0:
            # Transform the base sample through the posterior flow.
            z, log_det = self.flow_q.get_output_for(z, sample=sample)

        return z, log_det