Example #1
    def create_gradients(self, loss, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load network input
        X = self.inputs[0]
        x = X.flatten(2)

        # load network output
        if self.model == 'bernoulli':
            q_mu, q_logsigma, p_mu, z \
                 = lasagne.layers.get_output(self.network[2:], deterministic=deterministic)
        elif self.model == 'gaussian':
            raise NotImplementedError()

        # load params
        p_params, q_params = self._get_net_params()

        # entropy term
        log_qz_given_x = log_normal2(z, q_mu, q_logsigma).sum(axis=1)

        # expected p(x,z) term
        z_prior_sigma = T.cast(T.ones_like(q_logsigma),
                               dtype=theano.config.floatX)
        z_prior_mu = T.cast(T.zeros_like(q_mu), dtype=theano.config.floatX)
        log_pz = log_normal(z, z_prior_mu, z_prior_sigma).sum(axis=1)
        log_px_given_z = log_bernoulli(x, p_mu).sum(axis=1)
        log_pxz = log_pz + log_px_given_z

        # compute learning signals
        l = log_pxz - log_qz_given_x
        # l_avg, l_std = l.mean(), T.maximum(1, l.std())
        # c_new = 0.8*c + 0.2*l_avg
        # v_new = 0.8*v + 0.2*l_std
        # l = (l - c_new) / v_new

        # compute grad wrt p
        p_grads = T.grad(-log_pxz.mean(), p_params)

        # compute grad wrt q
        # q_target = T.mean(dg(l) * log_qz_given_x)
        # q_grads = T.grad(-0.2*q_target, q_params) # 5x slower rate for q
        log_qz_given_x = log_normal2(dg(z), q_mu, q_logsigma).sum(axis=1)
        q_target = T.mean(dg(l) * log_qz_given_x)
        q_grads = T.grad(-0.2 * q_target, q_params)  # 5x slower rate for q
        # q_grads = T.grad(-l.mean(), q_params) # 5x slower rate for q

        # # compute grad of cv net
        # cv_target = T.mean(l**2)
        # cv_grads = T.grad(cv_target, cv_params)

        # combine and clip gradients
        clip_grad = 1
        max_norm = 5
        grads = p_grads + q_grads
        mgrads = lasagne.updates.total_norm_constraint(grads,
                                                       max_norm=max_norm)
        cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]

        return cgrads
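These examples call log_bernoulli, log_normal, and log_normal2 without defining them; they are presumably imported at module level. Below is a minimal sketch of compatible helpers, assuming the common parameterization in which log_normal takes a standard deviation and log_normal2 takes a log-variance; the original project's definitions may differ, e.g. in how eps is applied.

import numpy as np
import theano.tensor as T

_C = -0.5 * np.log(2 * np.pi)

def log_normal(x, mean, sd, eps=0.0):
    # elementwise log N(x | mean, sd^2)
    return _C - T.log(T.abs_(sd) + eps) - (x - mean) ** 2 / (2 * (sd + eps) ** 2)

def log_normal2(x, mean, logvar, eps=0.0):
    # elementwise log N(x | mean, exp(logvar)), parameterized by log-variance
    return _C - logvar / 2 - (x - mean) ** 2 / (2 * T.exp(logvar) + eps)

def log_bernoulli(x, p, eps=0.0):
    # elementwise log Bernoulli(x | p), with optional clipping for stability
    p = T.clip(p, eps, 1.0 - eps)
    return -T.nnet.binary_crossentropy(p, x)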
Example #2
    def _create_components(self, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load network input
        X = self.inputs[0]
        x = X.flatten(2)

        # load networks
        l_p_mu, l_q_mu, l_q_sample, _, _, _ = self.network
        l_q_in, l_p_in, l_cv_in = self.input_layers

        # load network output
        z, q_mu = lasagne.layers.get_output(
            [l_q_sample, l_q_mu], deterministic=deterministic)
        p_mu = lasagne.layers.get_output(
            l_p_mu, {l_p_in: z},
            deterministic=deterministic,
        )

        # entropy term
        log_qz_given_x = log_bernoulli(dg(z), q_mu).sum(axis=1)

        # expected p(x,z) term
        z_prior = T.ones_like(z)*np.float32(0.5)
        log_pz = log_bernoulli(z, z_prior).sum(axis=1)
        log_px_given_z = log_bernoulli(x, p_mu).sum(axis=1)
        log_pxz = log_pz + log_px_given_z

        # save them for later
        self.log_pxz = log_pxz
        self.log_qz_given_x = log_qz_given_x

        return log_pxz.flatten(), log_qz_given_x.flatten()
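The pattern T.grad(T.mean(dg(l) * log_q), params) used throughout these examples is the score-function (REINFORCE) estimator: the learning signal l and the samples are detached from the graph, so the gradient flows only through log q. A tiny self-contained illustration with a toy Bernoulli variational parameter (hypothetical variables, not the model above):

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.gradient import disconnected_grad as dg

rng = RandomStreams(seed=0)
phi = theano.shared(0.0, name='phi')            # variational parameter
p = T.nnet.sigmoid(phi)                         # q(z = 1) = sigmoid(phi)
z = dg(rng.binomial(size=(16,), p=p))           # samples, treated as constants
log_q = z * T.log(p) + (1 - z) * T.log(1 - p)   # log q(z); grads flow via p only
l = 3.0 * z - 1.0                               # per-sample learning signal f(z)

# score-function estimate of d E_q[f(z)] / d phi; a noisy estimate of 3 p (1 - p)
grad_phi = T.grad(T.mean(dg(l) * log_q), phi)
print(theano.function([], grad_phi)())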
Example #3
    def create_gradients(self, loss, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load networks
        l_px_mu, l_px_logsigma, l_pa_mu, l_pa_logsigma, \
        l_qa_mu, l_qa_logsigma, l_qz_mu, l_qz_logsigma, l_qa, l_qz, l_cv, c, v = self.network

        # load params
        # p_params = lasagne.layers.get_all_params([l_px_mu], trainable=True)
        p_params = lasagne.layers.get_all_params(
            [l_px_mu, l_pa_mu, l_pa_logsigma], trainable=True)
        qa_params = lasagne.layers.get_all_params(l_qa_mu, trainable=True)
        qz_params = lasagne.layers.get_all_params(l_qz, trainable=True)
        cv_params = lasagne.layers.get_all_params(l_cv, trainable=True)

        # load neural net outputs (probabilities have been precomputed)
        log_pxz, log_px_given_z, log_pz = self.log_pxz, self.log_px_given_z, self.log_pz
        log_qza_given_x = self.log_qza_given_x
        log_qz_given_x = self.log_qz_given_x
        log_qz_given_x_dgz = self.log_qz_given_x_dgz
        cv = T.addbroadcast(lasagne.layers.get_output(l_cv), 1)

        # compute learning signals
        l0 = log_px_given_z + log_pz - log_qz_given_x  # - cv  # NOTE: this didn't have q(a)
        l_avg, l_var = l0.mean(), l0.var()
        c_new = 0.8 * c + 0.2 * l_avg
        v_new = 0.8 * v + 0.2 * l_var
        l_target = (l0 - c_new) / T.maximum(1, T.sqrt(v_new))
        # l_target = log_px_given_z + log_pz - log_qz_given_x

        # compute grad wrt p
        p_grads = T.grad(-log_pxz.mean(), p_params)

        # compute grad wrt q_a
        elbo = T.mean(log_pxz - log_qza_given_x)
        qa_grads = T.grad(-elbo, qa_params)

        # compute grad wrt q_z
        qz_target = T.mean(dg(l_target) * log_qz_given_x_dgz)
        qz_grads = T.grad(-0.2 * qz_target, qz_params)  # 5x slower rate for q
        # qz_grads = T.grad(-0.2*T.mean(l0), qz_params) # 5x slower rate for q
        # qz_grads = T.grad(-0.2*elbo, qz_params) # 5x slower rate for q

        # compute grad of cv net
        cv_target = T.mean(l0**2)
        # cv_grads = [0.2*g for g in T.grad(cv_target, cv_params)]

        # combine and clip gradients
        clip_grad = 1
        max_norm = 5
        # grads = p_grads + qa_grads + qz_grads + cv_grads
        grads = p_grads + qa_grads + qz_grads  #+ cv_grads
        mgrads = lasagne.updates.total_norm_constraint(grads,
                                                       max_norm=max_norm)
        cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]

        return cgrads
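Here c and v are shared variables holding running estimates of the mean and variance of the learning signal; the exponential moving averages c_new and v_new are meant to be written back as updates of the compiled training function, alongside the gradients returned above. A minimal self-contained sketch of that mechanism (toy signal, not the author's training loop):

import numpy as np
import theano
import theano.tensor as T

l0 = T.dvector('l0')                 # per-example learning signal
c = theano.shared(0.0, name='c')     # running mean of the signal
v = theano.shared(1.0, name='v')     # running variance of the signal
c_new = 0.8 * c + 0.2 * l0.mean()
v_new = 0.8 * v + 0.2 * l0.var()
l = (l0 - c_new) / T.maximum(1, T.sqrt(v_new))

# each call normalizes the signal and refreshes the running statistics
step = theano.function([l0], l, updates=[(c, c_new), (v, v_new)])
print(step(np.random.randn(8)))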
Example #4
    def create_gradients(self):
        k, s, n = self.n_qk, self.n_mc, self.n_visible
        (logF, logQ, logF_D_vec) = self._components
        logbnd, logZ = self.logbnd, self.logZ
        logF_D = logF_D_vec.mean()
        X = self.q_samples
        QQ = self.Qk.reshape((n,k,1))
        
        logF2 = 2.*logF
        logQ2 = 2.*logQ
        logFQ = logF - logQ # (t,)
        S = T.exp(logFQ - logZ) # (t,)
        S2 = T.exp(logF2 - logbnd - logQ2) # (t,)
        target = T.mean(S * logF)

        # get grads wrt params_p
        dlogB_W, dlogB_bv, dlogB_bh = theano.grad(logbnd, [self.W, self.vbias, self.hbias])
        dlogZ_W, dlogZ_bv, dlogZ_bh = theano.grad(target, [self.W, self.vbias, self.hbias], consider_constant=[S])
        dE_W, dE_bv, dE_bh = theano.grad(logF_D, [self.W, self.vbias, self.hbias])

        # get grads wrt params_q
        loga = - logbnd
        cv = T.exp(logZ + 0.5*loga)
        logF2a = logF2 + loga
        F2a = T.exp( logF2a )
        Q2 = T.exp(logQ2)
        cv_adj = cv**2. * Q2
        Scv = (F2a - cv_adj)/Q2
        # S = (F2a)/Q2
        # S = FQ2a

        Dq = X - QQ  # (n, k, s)
        Dq *= (-Scv).reshape((1, k, s))
        dlogB_Phi = T.mean(Dq, axis=2) * self.pi.reshape((1, k))  # (n, k); overwritten below

        from theano.gradient import disconnected_grad as dg
        dlogB_target = T.mean(-dg(S2) * logQ)
        dlogB_Phi = theano.grad(dlogB_target, self.Phi, consider_constant=[self.XX])        

        # log-likelihood / bound gradients (combine the above)
        
        # dL_Phi = -0.5*a*dB_Phi
        dL_Phi = -0.5*dlogB_Phi
        dL_W = dE_W - dlogZ_W 
        dL_bv = dE_bv - dlogZ_bv
        dL_bh = dE_bh - dlogZ_bh
        # dL_W = dE_W - 0.5 * dlogB_W 
        # dL_bv = dE_bv - 0.5 * dlogB_bv
        # dL_bh = dE_bh - 0.5 * dlogB_bh

        dL_Theta = [dL_W, dL_bv, dL_bh]
        dlogZ_Theta = [dlogZ_W, dlogZ_bv, dlogZ_bh]
        dE_Theta = [dE_W, dE_bv, dE_bh]

        return dL_Theta, dlogZ_Theta, dE_Theta, dL_Phi
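The RBM examples stop gradients in two equivalent ways: consider_constant=[S] inside theano.grad, and disconnected_grad applied to a factor of the objective. A small self-contained check of that equivalence on a toy graph:

import theano
import theano.tensor as T
from theano.gradient import disconnected_grad as dg

x = T.dscalar('x')
w = theano.shared(2.0, name='w')
s = w * x                # "importance weight" that should not be differentiated
obj = s * T.log(w)       # gradient should flow only through log(w)

g1 = theano.grad(obj, w, consider_constant=[s])
g2 = theano.grad(dg(s) * T.log(w), w)
f = theano.function([x], [g1, g2])
print(f(3.0))            # both print s / w = x = 3.0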
Example #5
    def create_llik(self):
        from theano.gradient import disconnected_grad as dg

        # load network input
        X = self.inputs[0]
        x = X.flatten(2)

        # duplicate entries to take into account multiple mc samples
        n_sam = self.n_sample
        n_out = x.shape[1]
        x = x.dimshuffle(0, 'x', 1).repeat(n_sam, axis=1).reshape((-1, n_out))

        # load networks
        l_px_mu, l_px_logsigma, l_pa_mu, l_pa_logsigma, \
        l_qa_mu, l_qa_logsigma, l_qz_mu, l_qz_logsigma, l_qa, l_qz, _, _, _ = self.network
        l_qa_in, l_qz_in, l_px_in, l_cv_in = self.input_layers

        # load network output
        qa_mu, qa_logsigma, a = lasagne.layers.get_output(
            [l_qa_mu, l_qa_logsigma, l_qa], )
        qz_mu, z = lasagne.layers.get_output(
            [l_qz_mu, l_qz],
            {
                l_qz_in: a,
                l_qa_in: X
            },
        )
        pa_mu, pa_logsigma = lasagne.layers.get_output(
            [l_pa_mu, l_pa_logsigma],
            {l_px_in: z},
        )

        px_mu = lasagne.layers.get_output(l_px_mu, {l_px_in: z})

        # entropy term
        log_qa_given_x = log_normal2(a, qa_mu, qa_logsigma).sum(axis=1)
        log_qz_given_x = log_bernoulli(z, qz_mu).sum(axis=1)
        log_qz_given_x_dgz = log_bernoulli(dg(z), qz_mu).sum(axis=1)
        log_qza_given_x = log_qz_given_x + log_qa_given_x

        # log-probability term
        # z_prior = T.ones_like(z)*np.float32(0.5)
        # log_pz = log_bernoulli(z, z_prior).sum(axis=1)
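        # NOTE: the reshape below assumes a fixed minibatch size of 128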
        log_e = -self.rbm.free_energy(z.reshape((128 * n_sam, self.n_lat)))
        log_px_given_z = log_bernoulli(x, px_mu).sum(axis=1)
        log_pa_given_z = log_normal2(a, pa_mu, pa_logsigma).sum(axis=1)

        t = log_pa_given_z + log_px_given_z + log_e - log_qz_given_x - log_qa_given_x
        t = t.reshape([128, n_sam])

        # compute loss
        llik = Tlogsumexp(t, axis=1)  # (n_bat,)

        return T.mean(llik)
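Tlogsumexp is not shown in the snippet; it is presumably the usual numerically stable log-sum-exp reduction over the Monte Carlo axis. A compatible sketch (an assumption, not the project's own helper):

import theano.tensor as T

def Tlogsumexp(x, axis=None):
    # numerically stable log(sum(exp(x))) along `axis`
    x_max = T.max(x, axis=axis, keepdims=True)
    return T.log(T.sum(T.exp(x - x_max), axis=axis)) + T.max(x, axis=axis)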
Example #6
    def create_gradients(self):
        (logFx, logpa, logQ, logF_D_vec) = self._components
        logFxa = logFx + logpa
        logbnd, logZ = self.logbnd, self.logZ
        logF_D = logF_D_vec.mean()

        logQ2 = 2. * logQ
        logFQ = logFx - logQ  # (t,)
        S = T.exp(logFQ - logZ)  # (t,)
        S2 = T.exp(2 * logFxa - logbnd - logQ2)  # (t,)
        target = T.mean(S * logFx)

        # get grads wrt params_p
        dlogB_W, dlogB_bv, dlogB_bh = theano.grad(
            logbnd, [self.W, self.vbias, self.hbias])
        dlogZ_W, dlogZ_bv, dlogZ_bh = theano.grad(
            target, [self.W, self.vbias, self.hbias], consider_constant=[S])
        dE_W, dE_bv, dE_bh = theano.grad(logF_D,
                                         [self.W, self.vbias, self.hbias])

        # get grads wrt params_qx
        from theano.gradient import disconnected_grad as dg
        dlogB_target = T.mean(-dg(S2) * logQ)
        dlogB_qx = theano.grad(dlogB_target, self.get_params_qx())

        # get grads wrt params_pa
        dlogB_pa = theano.grad(T.mean(S2),
                               self.get_params_pa(),
                               consider_constant=[logbnd])

        # log-likelihood / bound gradients (combine the above)

        dL_qx = [-0.5 * g for g in dlogB_qx]
        dL_pa = [-0.5 * g for g in dlogB_pa]
        dL_W = dE_W - dlogZ_W
        dL_bv = dE_bv - dlogZ_bv
        dL_bh = dE_bh - dlogZ_bh
        # dL_W = dE_W - 0.5 * dlogB_W
        # dL_bv = dE_bv - 0.5 * dlogB_bv
        # dL_bh = dE_bh - 0.5 * dlogB_bh

        dL_Theta = [dL_W, dL_bv, dL_bh]
        dlogZ_Theta = [dlogZ_W, dlogZ_bv, dlogZ_bh]
        dE_Theta = [dE_W, dE_bv, dE_bh]

        return dL_Theta, dlogZ_Theta, dE_Theta, dL_qx, dL_pa
Example #7
    def create_gradients(self, loss, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load networks
        l_p_mu, l_q_mu, _, l_cv, c, v = self.network

        # load params
        p_params  = lasagne.layers.get_all_params(l_p_mu, trainable=True)
        q_params  = lasagne.layers.get_all_params(l_q_mu, trainable=True)
        cv_params = lasagne.layers.get_all_params(l_cv, trainable=True)

        # load neural net outputs (probabilities have been precomputed)
        log_pxz, log_qz_given_x = self.log_pxz, self.log_qz_given_x
        cv = T.addbroadcast(lasagne.layers.get_output(l_cv),1)

        # compute learning signals
        l = log_pxz - log_qz_given_x - cv
        l_avg, l_var = l.mean(), l.var()
        c_new = 0.8*c + 0.2*l_avg
        v_new = 0.8*v + 0.2*l_var
        l = (l - c_new) / T.maximum(1, T.sqrt(v_new))

        # compute grad wrt p
        p_grads = T.grad(-log_pxz.mean(), p_params)

        # compute grad wrt q
        q_target = T.mean(dg(l) * log_qz_given_x)
        q_grads = T.grad(-0.2*q_target, q_params) # 5x slower rate for q

        # compute grad of cv net
        cv_target = T.mean(l**2)
        cv_grads = T.grad(cv_target, cv_params)

        # combine and clip gradients
        clip_grad = 1
        max_norm = 5
        grads = p_grads + q_grads + cv_grads
        mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=max_norm)
        cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]

        return cgrads
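The clipped gradient lists (cgrads) returned by these create_gradients methods can be fed straight to a Lasagne update rule, since lasagne.updates.adam accepts a list of gradient expressions in place of a scalar loss. A small standalone sketch of that wiring with toy parameters (not the original training script):

import numpy as np
import theano
import theano.tensor as T
import lasagne

w = theano.shared(np.zeros(3), name='w')
x = T.dvector('x')
loss = T.sum((w - x) ** 2)
cgrads = [T.clip(g, -1, 1) for g in T.grad(loss, [w])]   # clipped, precomputed grads
updates = lasagne.updates.adam(cgrads, [w], learning_rate=1e-2)
step = theano.function([x], loss, updates=updates)
print(step(np.ones(3)))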
Example #8
def lower_bound(z, z_mu, x_mu, x, eq_samples, iw_samples, epsilon=1e-6):
    from theano.gradient import disconnected_grad as dg
    # reshape the variables so batch_size, eq_samples and iw_samples are
    # separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # prepare x, z for broadcasting
    # size: (batch_size, eq_samples, iw_samples, num_features)
    x = x.dimshuffle(0, 'x', 'x', 1)

    # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)

    log_qz_given_x = log_bernoulli(z, z_mu, eps=epsilon).sum(axis=3)
    z_prior = T.ones_like(z)*np.float32(0.5)
    log_pz = log_bernoulli(z, z_prior).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, eps=epsilon).sum(axis=3)

    # Calculate the LL using log-sum-exp to avoid underflow
    log_pxz = log_pz + log_px_given_z

    # L is (bs, mc); see the definition of L in appendix eq. 14
    L = log_sum_exp(log_pxz - log_qz_given_x, axis=2) + \
        T.log(1.0/T.cast(iw_samples, 'float32'))

    grads_model = T.grad(-L.mean(), p_params)

    # L_corr should correspond to equation 10 in the paper
    L_corr = L.dimshuffle(0, 1, 'x') - get_vimco_baseline(
        log_pxz - log_qz_given_x)
    g_lb_inference = T.mean(T.sum(dg(L_corr) * log_qz_given_x) + L)
    grads_inference = T.grad(-g_lb_inference, q_params)

    grads = grads_model + grads_inference
    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x, axis=2)
    return (LL,
            T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z),
            grads)
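lower_bound relies on several names from the enclosing module: latent_size, num_features, p_params, q_params, the stable reductions log_sum_exp / log_mean_exp, and get_vimco_baseline. The in-code comment ties the baseline to equation 10 of VIMCO (Mnih & Rezende, 2016), i.e. the leave-one-out baseline in which each importance weight is replaced by the geometric mean of the remaining ones. A hedged reconstruction of compatible helpers (an assumed sketch, not the original code):

import theano.tensor as T

def log_sum_exp(x, axis=None):
    # numerically stable log(sum(exp(x))) along `axis`
    x_max = T.max(x, axis=axis, keepdims=True)
    return T.log(T.sum(T.exp(x - x_max), axis=axis)) + T.max(x, axis=axis)

def log_mean_exp(x, axis):
    return log_sum_exp(x, axis=axis) - T.log(T.cast(x.shape[axis], 'float32'))

def get_vimco_baseline(log_f):
    # log_f: (batch, eq_samples, iw_samples) log importance weights
    K = log_f.shape[2]
    Kf = T.cast(K, 'float32')
    # log geometric mean of the other K-1 weights, for each position j
    log_gmean_rest = (T.sum(log_f, axis=2, keepdims=True) - log_f) / (Kf - 1.0)
    # (batch, eq, K, K) tensor whose j-th row is log_f with entry j replaced
    # by that geometric mean
    r = T.arange(K)
    mask = T.eq(r.dimshuffle(0, 'x'), r.dimshuffle('x', 0))   # (K, K) identity
    swapped = (log_f.dimshuffle(0, 1, 'x', 2) * (1.0 - mask) +
               log_gmean_rest.dimshuffle(0, 1, 2, 'x') * mask)
    # leave-one-out log-mean-exp over the last axis -> baseline for sample j
    return log_sum_exp(swapped, axis=3) - T.log(Kf)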
Example #9
def lower_bound(z, z_mu, x_mu, x, eq_samples, iw_samples, epsilon=1e-6):
    from theano.gradient import disconnected_grad as dg
    # reshape the variables so batch_size, eq_samples and iw_samples are
    # separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # prepare x, z for broadcasting
    # size: (batch_size, eq_samples, iw_samples, num_features)
    x = x.dimshuffle(0, 'x', 'x', 1)

    # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)

    log_qz_given_x = log_bernoulli(z, z_mu, eps=epsilon).sum(axis=3)
    z_prior = T.ones_like(z) * np.float32(0.5)
    log_pz = log_bernoulli(z, z_prior).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, eps=epsilon).sum(axis=3)

    # Calculate the LL using log-sum-exp to avoid underflow
    log_pxz = log_pz + log_px_given_z

    # L is (bs, mc); see the definition of L in appendix eq. 14
    L = log_sum_exp(log_pxz - log_qz_given_x, axis=2) + \
        T.log(1.0/T.cast(iw_samples, 'float32'))

    grads_model = T.grad(-L.mean(), p_params)

    # L_corr should correspond to equation 10 in the paper
    L_corr = L.dimshuffle(0, 1,
                          'x') - get_vimco_baseline(log_pxz - log_qz_given_x)
    g_lb_inference = T.mean(T.sum(dg(L_corr) * log_qz_given_x) + L)
    grads_inference = T.grad(-g_lb_inference, q_params)

    grads = grads_model + grads_inference
    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x, axis=2)
    return (LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z),
            grads)
Example #10
    def _create_components(self, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load network input
        X = self.inputs[0]
        x = X.flatten(2)

        # load networks
        l_px_mu, l_px_logsigma, l_pa_mu, l_pa_logsigma, \
        l_qa_mu, l_qa_logsigma, l_qz_mu, l_qz_logsigma, l_qa, l_qz, _, _, _ = self.network
        l_qa_in, l_qz_in, l_px_in = self.input_layers

        # load network output
        qa_mu, qa_logsigma, a = lasagne.layers.get_output(
            [l_qa_mu, l_qa_logsigma, l_qa], deterministic=deterministic)
        qz_mu, z = lasagne.layers.get_output(
            [l_qz_mu, l_qz],
            # {l_qz_in : T.zeros_like(qa_mu), l_qa_in : X},
            # {l_qz_in : qa_mu, l_qa_in : X},
            {
                l_qz_in: a,
                l_qa_in: X
            },
            deterministic=deterministic)
        pa_mu, pa_logsigma = lasagne.layers.get_output(
            [l_pa_mu, l_pa_logsigma], z, deterministic=deterministic)

        if self.model == 'bernoulli':
            px_mu = lasagne.layers.get_output(l_px_mu,
                                              z,
                                              deterministic=deterministic)
        elif self.model == 'gaussian':
            px_mu, px_logsigma = lasagne.layers.get_output(
                [l_px_mu, l_px_logsigma], z, deterministic=deterministic)

        # entropy term
        log_qa_given_x = log_normal2(a, qa_mu, qa_logsigma).sum(axis=1)
        log_qz_given_x = log_bernoulli(z, qz_mu).sum(axis=1)
        log_qz_given_x_dgz = log_bernoulli(dg(z), qz_mu).sum(axis=1)
        # log_qz_given_x = log_normal2(z, qz_mu, qz_logsigma).sum(axis=1)
        # log_qz_given_x_dgz = log_normal2(dg(z), qz_mu, qz_logsigma).sum(axis=1)
        log_qza_given_x = log_qz_given_x + log_qa_given_x

        # log-probability term
        z_prior = T.ones_like(z) * np.float32(0.5)
        log_pz = log_bernoulli(z, z_prior).sum(axis=1)
        # z_prior_sigma = T.cast(T.ones_like(qz_logsigma), dtype=theano.config.floatX)
        # z_prior_mu = T.cast(T.zeros_like(qz_mu), dtype=theano.config.floatX)
        # log_pz = log_normal(z, z_prior_mu,  z_prior_sigma).sum(axis=1)
        log_px_given_z = log_bernoulli(x, px_mu).sum(axis=1)
        log_pa_given_z = log_normal2(a, pa_mu, pa_logsigma).sum(axis=1)

        log_pxz = log_pa_given_z + log_px_given_z + log_pz

        # save them for later
        if not deterministic:
            self.log_pxz = log_pxz
            self.log_px_given_z = log_px_given_z
            self.log_pz = log_pz
            self.log_qza_given_x = log_qza_given_x
            self.log_qa_given_x = log_qa_given_x
            self.log_qz_given_x = log_qz_given_x
            self.log_qz_given_x_dgz = log_qz_given_x_dgz

        # return log_paxz, log_qza_given_x
        return log_pxz, log_qza_given_x
Example #11
    def _create_components(self, deterministic=False):
        from theano.gradient import disconnected_grad as dg

        # load network input
        X = self.inputs[0]
        x = X.flatten(2)

        # duplicate entries to take into account multiple mc samples
        n_sam = self.n_sample
        n_out = x.shape[1]
        x = x.dimshuffle(0, 'x', 1).repeat(n_sam, axis=1).reshape((-1, n_out))

        # load networks
        l_px_mu, l_px_logsigma, l_pa_mu, l_pa_logsigma, \
        l_qa_mu, l_qa_logsigma, l_qz_mu, l_qz_logsigma, l_qa, l_qz, _, _, _ = self.network
        l_qa_in, l_qz_in, l_px_in, l_cv_in = self.input_layers

        # load network output
        qa_mu, qa_logsigma, a = lasagne.layers.get_output(
            [l_qa_mu, l_qa_logsigma, l_qa],
            deterministic=deterministic,
        )
        qz_mu, z = lasagne.layers.get_output(
            [l_qz_mu, l_qz],
            {
                l_qz_in: a,
                l_qa_in: X
            },
            deterministic=deterministic,
        )
        pa_mu, pa_logsigma = lasagne.layers.get_output(
            [l_pa_mu, l_pa_logsigma],
            {l_px_in: z},
            deterministic=deterministic,
        )

        if self.model == 'bernoulli':
            px_mu = lasagne.layers.get_output(l_px_mu, {l_px_in: z},
                                              deterministic=deterministic)
        elif self.model == 'gaussian':
            px_mu, px_logsigma = lasagne.layers.get_output(
                [l_px_mu, l_px_logsigma],
                {l_px_in: z},
                deterministic=deterministic,
            )

        # entropy term
        log_qa_given_x = log_normal2(a, qa_mu, qa_logsigma).sum(axis=1)
        log_qz_given_x = log_bernoulli(z, qz_mu).sum(axis=1)
        log_qz_given_x_dgz = log_bernoulli(dg(z), qz_mu).sum(axis=1)
        log_qza_given_x = log_qz_given_x + log_qa_given_x

        # log-probability term
        z_prior = T.ones_like(z) * np.float32(0.5)
        log_pz = log_bernoulli(z, z_prior).sum(axis=1)
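        # NOTE: the reshape below assumes a fixed minibatch size of 128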
        log_e = -self.rbm.free_energy(z.reshape((128 * n_sam, self.n_lat)))
        log_px_given_z = log_bernoulli(x, px_mu).sum(axis=1)
        log_pa_given_z = log_normal2(a, pa_mu, pa_logsigma).sum(axis=1)

        log_pxz = log_pa_given_z + log_px_given_z + log_e

        # save them for later
        if not deterministic:
            self.log_pxz = log_pxz
            self.log_px_given_z = log_px_given_z
            self.log_pz = log_pz
            self.log_qza_given_x = log_qza_given_x
            self.log_qa_given_x = log_qa_given_x
            self.log_qz_given_x = log_qz_given_x
            self.log_qz_given_x_dgz = log_qz_given_x_dgz
            self.log_e = log_e.mean()
            self.z = z

        # return log_paxz, log_qza_given_x
        return log_pxz, log_qza_given_x