def log_p_y_z(self):
        """Return the symbolic log-likelihood log p(y | z) of the mini-batch.

        The sampled latents ``self.z`` are pushed through the decoder network
        (one hidden layer, or two when ``self.numHiddenLayers_decoder == 2``)
        and the output is scored against ``self.y_miniBatch``:

        * ``self.continuous`` true: diagonal Gaussian likelihood whose mean and
          log standard deviation are both linear read-outs of the hidden layer.
        * otherwise: the negative binary cross-entropy between the sigmoid
          read-out and the data.

        In both branches the decoder works on transposed inputs, so ``self.z``
        and ``self.y_miniBatch`` are transposed before use.

        Returns:
            A scalar expression named ``'log_p_y_z'``: the log-likelihood
            summed over every element of the mini-batch.
        """
        if self.continuous:
            h_decoder = softplus(dot(self.W_zh, self.z.T) + self.b_zh)
            if self.numHiddenLayers_decoder == 2:
                h_decoder = softplus(dot(self.W_hh, h_decoder) + self.b_hh)
            mu_decoder = dot(self.W_hy1, h_decoder) + self.b_hy1
            # The second read-out is halved and then used as the log standard
            # deviation in the Gaussian density below.
            log_sigma_decoder = 0.5 * (dot(self.W_hy2, h_decoder) + self.b_hy2)
            standardised_residual = \
                (self.y_miniBatch.T - mu_decoder) / T.exp(log_sigma_decoder)
            log_pyz = T.sum(-(0.5 * np.log(2 * np.pi) + log_sigma_decoder)
                            - 0.5 * standardised_residual ** 2)

            log_sigma_decoder.name = 'log_sigma_decoder'
            mu_decoder.name        = 'mu_decoder'
            h_decoder.name         = 'h_decoder'
            log_pyz.name           = 'log_p_y_z'
        else:
            # z and y are transposed exactly as in the continuous branch so
            # that the same W_zh / W_hy1 weights line up with their inputs.
            h_decoder = tanh(dot(self.W_zh, self.z.T) + self.b_zh)
            if self.numHiddenLayers_decoder == 2:
                # W_hh is an attribute; the bare name is not defined here.
                h_decoder = softplus(dot(self.W_hh, h_decoder) + self.b_hh)
            y_hat = sigmoid(dot(self.W_hy1, h_decoder) + self.b_hy1)
            log_pyz = -T.nnet.binary_crossentropy(y_hat, self.y_miniBatch.T).sum()

            h_decoder.name = 'h_decoder'
            y_hat.name     = 'y_hat'
            log_pyz.name   = 'log_p_y_z'

        return log_pyz
    def create_new_data_function(self):
        """Compile a function that decodes ``self.z_test`` into a decoder mean.

        ``self.z_test`` must already exist (``reconstruct_test_datum`` assigns
        it).  It is transposed and passed through the decoder's hidden
        layer(s) and the linear ``W_hy1`` / ``b_hy1`` read-out.  The compiled,
        argument-free function is stored in ``self.new_data_function``.

        Returns:
            The symbolic decoder mean that the compiled function evaluates.
        """
        first_layer_input = dot(self.W_zh, self.z_test.T) + self.b_zh
        hidden = softplus(first_layer_input)

        if self.numHiddenLayers_decoder == 2:
            second_layer_input = dot(self.W_hh, hidden) + self.b_hh
            hidden = softplus(second_layer_input)

        decoded_mean = dot(self.W_hy1, hidden) + self.b_hy1

        self.new_data_function = th.function(
            [], decoded_mean, no_default_updates=True)

        return decoded_mean
    def reconstruct_test_datum(self):
        """Build the symbolic graph that encodes and re-samples one datum.

        One row of ``self.y`` is selected at random (the index is drawn with
        NumPy once, when this method runs), passed through the q(X) encoder
        network and the sparse-GP conditional, and the resulting latent sample
        is stored in ``self.z_test``.  Every intermediate expression is kept
        on the instance under a ``*_test`` attribute.

        The method reads ``W1_qX`` .. ``b3_qX``, which ``__init__`` creates
        only when ``encoderType_qX == 'MLP'``.

        Returns:
            None.  Results are stored as attributes.
        """
        # self.y is a shared variable, not a callable: select the row by
        # indexing, the same way __init__ builds y_miniBatch.
        test_index = np.random.choice(self.N, 1)
        self.y_test = self.y[test_index, :]

        # Encoder network for q(X), mirroring the 'MLP' branch of __init__.
        h_qX = softplus(plus(dot(self.W1_qX, self.y_test.T), self.b1_qX))
        mu_qX = plus(dot(self.W2_qX, h_qX), self.b2_qX)
        log_sigma_qX = mul(0.5, plus(dot(self.W3_qX, h_qX), self.b3_qX))

        self.phi_test = mu_qX.T
        (self.Phi_test, self.cPhi_test, self.iPhi_test, self.logDetPhi_test) \
            = diagCholInvLogDet_fromLogDiag(log_sigma_qX)

        # NOTE(review): xi[0,:] drops to a vector before the dot with
        # cPhi_test - confirm the shapes agree for a single datum.
        self.Xz_test = plus(self.phi_test, dot(self.cPhi_test, self.xi[0, :]))

        # NOTE(review): `kfactory` is not defined in this method; in __init__
        # it is only a local variable.  Unless a module-level `kfactory`
        # exists these two lines raise NameError - confirm.
        self.Kzz_test = kfactory.kernel(self.Xz_test, None,    self.log_theta)
        self.Kzu_test = kfactory.kernel(self.Xz_test, self.Xu, self.log_theta)

        self.A_test = dot(self.Kzu_test, self.iKuu)
        self.C_test = minus(self.Kzz_test, dot(self.A_test, self.Kzu_test.T))
        # NOTE(review): self.B is passed as the size although only one row
        # was selected above - confirm this is what cholInvLogDet expects.
        self.cC_test, self.iC_test, self.logDetC_test \
            = cholInvLogDet(self.C_test, self.B, self.jitter)

        # Sample of the inducing outputs, built as in __init__: kappa + cKappa . alpha
        self.u_test = plus(self.kappa, dot(self.cKappa, self.alpha))

        self.mu_test = dot(self.A_test, self.u_test)

        self.z_test = plus(self.mu_test, dot(self.cC_test, self.beta[0, :]))
    def __init__(self,
                 numberOfInducingPoints,  # Number of inducing points in sparse GP
                 batchSize,              # Size of mini batch
                 dimX,                   # Dimensionality of the latent co-ordinates
                 dimZ,                   # Dimensionality of the latent variables
                 data,                   # [NxP] matrix of observations
                 kernelType='ARD',       # 'RBF', 'RBFnn' or 'ARD'
                 encoderType_qX='FreeForm2',  # 'FreeForm1', 'FreeForm2', 'MLP' or 'Kernel'
                 encoderType_rX='FreeForm2',  # 'MLP' or 'Kernel' (see NOTE in docstring)
                 Xu_optimise=False,
                 numberOfEncoderHiddenUnits=10
                 ):
        """Build the symbolic model graph and allocate its shared parameters.

        The constructor stores the data in a shared variable, sets up the
        random mini-batch index streams, creates the variational
        distributions q(X) and q(u), the kernel matrices, the sampled latents
        ``self.z`` and the auxiliary distribution r(X|z), and finally compiles
        a few diagnostic functions (condition numbers, current batch, Xz).

        Args:
            numberOfInducingPoints: stored as ``self.M``.
            batchSize: stored as ``self.B``.
            dimX: stored as ``self.R``.
            dimZ: stored as ``self.Q``.
            data: array-like converted with ``np.asarray`` to the module-level
                ``precision`` dtype; its two dimensions become ``self.N`` and
                ``self.P``.
            kernelType: 'RBF', 'RBFnn' or 'ARD'; fixes the number of kernel
                parameters and is passed to ``kernelFactory``.
            encoderType_qX: 'FreeForm1', 'FreeForm2', 'MLP' or 'Kernel'.
            encoderType_rX: 'MLP' or 'Kernel'.
            Xu_optimise: when true, ``self.Xu`` is added to
                ``self.gradientVariables``.
            numberOfEncoderHiddenUnits: stored as ``self.H``; the hidden width
                of the MLP encoders.

        Raises:
            RuntimeError: if ``kernelType``, ``encoderType_qX`` or
                ``encoderType_rX`` is not one of the values listed above.

        NOTE(review): the default ``encoderType_rX='FreeForm2'`` is not handled
        by the r(X|z) section below, so constructing with the defaults reaches
        the final ``raise RuntimeError`` - confirm whether a subclass or the
        callers always override it.
        """

        self.numTestSamples = 5000

        # set the data
        data = np.asarray(data, dtype=precision)
        self.N = data.shape[0]  # Number of observations
        self.P = data.shape[1]  # Dimension of each observation
        self.M = numberOfInducingPoints
        self.B = batchSize
        self.R = dimX
        self.Q = dimZ
        self.H = numberOfEncoderHiddenUnits

        self.encoderType_qX = encoderType_qX
        self.encoderType_rX = encoderType_rX
        self.Xu_optimise = Xu_optimise

        self.y = th.shared(data)
        self.y.name = 'y'

        # Number of kernel hyper-parameters for each supported kernel type
        if kernelType == 'RBF':
            self.numberOfKernelParameters = 2
        elif kernelType == 'RBFnn':
            self.numberOfKernelParameters = 1
        elif kernelType == 'ARD':
            self.numberOfKernelParameters = self.R + 1
        else:
            raise RuntimeError('Unrecognised kernel type')

        self.lowerBound = -np.inf  # Lower bound

        # Batches per epoch is ceil(N / B); numPad is the number of extra
        # indices needed so that the last batch is also of size B.
        self.numberofBatchesPerEpoch = int(np.ceil(np.float32(self.N) / self.B))
        numPad = self.numberofBatchesPerEpoch * self.B - self.N

        # batchStream is a random permutation of the N indices; padStream
        # draws numPad further indices from the same range without replacement.
        self.batchStream = srng.permutation(n=self.N)
        self.padStream   = srng.choice(size=(numPad,), a=self.N,
                                       replace=False, p=None, ndim=None, dtype='int32')

        self.batchStream.name = 'batchStream'
        self.padStream.name = 'padStream'

        # Shared counter selecting which row of allBatches is the current batch
        self.iterator = th.shared(0)
        self.iterator.name = 'iterator'

        # [numberofBatchesPerEpoch x B] table of indices; row `iterator` is the current batch
        self.allBatches = T.reshape(T.concatenate((self.batchStream, self.padStream)), [self.numberofBatchesPerEpoch, self.B])
        self.currentBatch = T.flatten(self.allBatches[self.iterator, :])

        self.allBatches.name = 'allBatches'
        self.currentBatch.name = 'currentBatch'

        # Rows of y belonging to the current batch
        self.y_miniBatch = self.y[self.currentBatch, :]
        self.y_miniBatch.name = 'y_miniBatch'

        # Jitter passed to cholInvLogDet below; the default and growth factor
        # are stored here but not otherwise used in this method.
        self.jitterDefault = np.float64(0.0001)
        self.jitterGrowthFactor = np.float64(1.1)
        self.jitter = th.shared(np.asarray(self.jitterDefault, dtype='float64'), name='jitter')

        # NOTE: kfactory is a local variable; it is not stored on the instance.
        kfactory = kernelFactory(kernelType)

        # kernel parameters
        self.log_theta = sharedZeroMatrix(1, self.numberOfKernelParameters, 'log_theta', broadcastable=(True,False)) # parameters of Kuu, Kzu, Kzz
        self.log_omega = sharedZeroMatrix(1, self.numberOfKernelParameters, 'log_omega', broadcastable=(True,False)) # parameters of Tau_r (r(X|z) 'Kernel' encoder)
        self.log_gamma = sharedZeroMatrix(1, self.numberOfKernelParameters, 'log_gamma', broadcastable=(True,False)) # parameters of Phi (q(X) 'Kernel' encoder)

        # Random variables (standard normal draws)
        self.xi    = srng.normal(size=(self.B, self.R), avg=0.0, std=1.0, ndim=None)
        self.alpha = srng.normal(size=(self.M, self.Q), avg=0.0, std=1.0, ndim=None)
        self.beta  = srng.normal(size=(self.B, self.Q), avg=0.0, std=1.0, ndim=None)
        self.xi.name    = 'xi'
        self.alpha.name = 'alpha'
        self.beta.name  = 'beta'

        # Compiled samplers for the random variables and the index streams
        self.sample_xi    = th.function([], self.xi)
        self.sample_alpha = th.function([], self.alpha)
        self.sample_beta  = th.function([], self.beta)

        self.sample_batchStream = th.function([], self.batchStream)
        self.sample_padStream   = th.function([], self.padStream)

        # Compiled with no_default_updates=True, unlike the samplers above
        self.getCurrentBatch = th.function([], self.currentBatch, no_default_updates=True)

        # Compute parameters of q(X)
        if self.encoderType_qX == 'FreeForm1' or self.encoderType_qX == 'FreeForm2':
            # Have a normal variational distribution over location of latent co-ordinates

            # One mean row per observation; the current batch's rows are selected
            self.phi_full = sharedZeroMatrix(self.N, self.R, 'phi_full')
            self.phi = self.phi_full[self.currentBatch, :]
            self.phi.name = 'phi'

            if encoderType_qX == 'FreeForm1':

                # Full [NxN] square-root factor; Phi is built from the
                # batch's sub-block as sqrt . sqrt^T
                self.Phi_full_sqrt = sharedZeroMatrix(self.N, self.N, 'Phi_full_sqrt')

                Phi_batch_sqrt = self.Phi_full_sqrt[self.currentBatch][:, self.currentBatch]
                Phi_batch_sqrt.name = 'Phi_batch_sqrt'

                self.Phi = dot(Phi_batch_sqrt, Phi_batch_sqrt.T, 'Phi')

                # The inverse is discarded, so self.iPhi is not set in this branch
                self.cPhi, _, self.logDetPhi = cholInvLogDet(self.Phi, self.B, 0)

                self.qX_vars = [self.Phi_full_sqrt, self.phi_full]

            else:

                # 'FreeForm2': one log-diagonal entry per observation
                self.Phi_full_logdiag = sharedZeroArray(self.N, 'Phi_full_logdiag')

                Phi_batch_logdiag = self.Phi_full_logdiag[self.currentBatch]
                Phi_batch_logdiag.name = 'Phi_batch_logdiag'

                # The inverse is discarded, so self.iPhi is not set in this branch
                self.Phi, self.cPhi, _, self.logDetPhi \
                    = diagCholInvLogDet_fromLogDiag(Phi_batch_logdiag, 'Phi')

                self.qX_vars = [self.Phi_full_logdiag, self.phi_full]

        elif self.encoderType_qX == 'MLP':

            # Auto encode: q(X) parameters are produced by a network applied to y_miniBatch
            self.W1_qX = sharedZeroMatrix(self.H, self.P, 'W1_qX')
            self.W2_qX = sharedZeroMatrix(self.R, self.H, 'W2_qX')
            self.W3_qX = sharedZeroMatrix(1, self.H, 'W3_qX')
            self.b1_qX = sharedZeroVector(self.H, 'b1_qX', broadcastable=(False, True))
            self.b2_qX = sharedZeroVector(self.R, 'b2_qX', broadcastable=(False, True))
            self.b3_qX = sharedZeroVector(1, 'b3_qX', broadcastable=(False, True))

            # [HxB] = softplus( [HxP] . [BxP]^T + repmat([Hx1],[1,B]) )
            h_qX = softplus(plus(dot(self.W1_qX, self.y_miniBatch.T), self.b1_qX), 'h_qX' )
            # [RxB] = [RxH] . [HxB] + repmat([Rx1],[1,B])   (linear read-out)
            mu_qX = plus(dot(self.W2_qX, h_qX), self.b2_qX, 'mu_qX')
            # [1xB] = 0.5 * ( [1xH] . [HxB] + repmat([1x1],[1,B]) )
            log_sigma_qX = mul( 0.5, plus(dot(self.W3_qX, h_qX), self.b3_qX), 'log_sigma_qX')

            self.phi  = mu_qX.T  # [BxR]
            self.Phi, self.cPhi, self.iPhi,self.logDetPhi \
                = diagCholInvLogDet_fromLogDiag(log_sigma_qX, 'Phi')

            self.qX_vars = [self.W1_qX, self.W2_qX, self.W3_qX, self.b1_qX, self.b2_qX, self.b3_qX]

        elif self.encoderType_qX == 'Kernel':

            # Draw the latent coordinates from a GP with data co-ordinates
            self.Phi = kfactory.kernel(self.y_miniBatch, None, self.log_gamma, 'Phi')
            # NOTE(review): phi is a shared matrix here but is not listed in
            # qX_vars below - confirm that it is meant to stay fixed.
            self.phi = sharedZeroMatrix(self.B, self.R, 'phi')
            (self.cPhi, self.iPhi, self.logDetPhi) = cholInvLogDet(self.Phi, self.B, self.jitter)

            self.qX_vars = [self.log_gamma]

        else:
            raise RuntimeError('Unrecognised encoding for q(X): ' + self.encoderType_qX)

        # Variational distribution q(u): mean kappa, covariance Kappa = Kappa_sqrt . Kappa_sqrt^T
        self.kappa = sharedZeroMatrix(self.M, self.Q, 'kappa')
        self.Kappa_sqrt = sharedZeroMatrix(self.M, self.M, 'Kappa_sqrt')
        self.Kappa = dot(self.Kappa_sqrt, self.Kappa_sqrt.T, 'Kappa')

        (self.cKappa, self.iKappa, self.logDetKappa) \
                    = cholInvLogDet(self.Kappa, self.M, 0)
        self.qu_vars = [self.Kappa_sqrt, self.kappa]

        # Calculate latent co-ordinates Xz (the expression is named 'Xf')
        # [BxR]  = [BxR] + [BxB] . [BxR]
        self.Xz = plus( self.phi, dot(self.cPhi, self.xi), 'Xf' )
        # Inducing points co-ordinates
        self.Xu = sharedZeroMatrix(self.M, self.R, 'Xu')

        # Kernels (the expressions are named 'Kff', 'Kuu', 'Kfu')
        self.Kzz = kfactory.kernel(self.Xz, None,    self.log_theta, 'Kff')
        self.Kuu = kfactory.kernel(self.Xu, None,    self.log_theta, 'Kuu')
        self.Kzu = kfactory.kernel(self.Xz, self.Xu, self.log_theta, 'Kfu')
        self.cKuu, self.iKuu, self.logDetKuu = cholInvLogDet(self.Kuu, self.M, self.jitter)

        # Variational distribution
        # A has dims [BxM] = [BxM] . [MxM]
        self.A = dot(self.Kzu, self.iKuu, 'A')
        # C is the covariance of conditional distribution q(z|u,Xz): Kzz - A . Kzu^T
        self.C = minus( self.Kzz, dot(self.A, self.Kzu.T), 'C')
        self.cC, self.iC, self.logDetC = cholInvLogDet(self.C, self.B, self.jitter)

        # Sample u from q(u) = N(u; kappa, Kappa )  [MxQ]
        self.u  = plus(self.kappa, (dot(self.cKappa, self.alpha)), 'u')
        # compute mean of z
        # [BxQ] = [BxM] . [MxQ]
        self.mu = dot(self.A, self.u, 'mu')
        # Sample z from q(z|u,X) = N( mu, C )
        # [BxQ] = [BxQ] + [BxB] . [BxQ]
        self.z  = plus(self.mu, (dot(self.cC, self.beta)), 'z')

        self.qz_vars = [self.log_theta]

        # iUpsilon = iKappa + A^T . iC . A.  cholInvLogDet is applied to
        # iUpsilon itself, so its "inverse" output is kept as Upsilon and its
        # log-determinant output as negLogDetUpsilon.
        self.iUpsilon = plus(self.iKappa, dot(self.A.T, dot(self.iC, self.A) ), 'iUpsilon')
        _, self.Upsilon, self.negLogDetUpsilon = cholInvLogDet(self.iUpsilon, self.M, self.jitter)

        # Auxiliary distribution r(X|z)
        if self.encoderType_rX == 'MLP':

            # The network input is the concatenation of z and y, hence Q+P columns
            self.W1_rX = sharedZeroMatrix(self.H, self.Q+self.P, 'W1_rX')
            self.W2_rX = sharedZeroMatrix(self.R, self.H, 'W2_rX')
            self.W3_rX = sharedZeroMatrix(self.R, self.H, 'W3_rX')
            self.b1_rX = sharedZeroVector(self.H, 'b1_rX', broadcastable=(False, True))
            self.b2_rX = sharedZeroVector(self.R, 'b2_rX', broadcastable=(False, True))
            self.b3_rX = sharedZeroVector(self.R, 'b3_rX', broadcastable=(False, True))

            # [HxB] = softplus( [Hx(Q+P)] . [(Q+P)xB] + repmat([Hx1], [1,B]) )
            h_rX = softplus(plus(dot(self.W1_rX, T.concatenate((self.z.T, self.y_miniBatch.T))), self.b1_rX), 'h_rX')
            # [RxB] = [RxH] . [HxB] + repmat([Rx1], [1,B])   (linear read-out)
            mu_rX = plus(dot(self.W2_rX, h_rX), self.b2_rX, 'mu_rX')
            # [RxB] = 0.5*( [RxH] . [HxB] + repmat([Rx1], [1,B]) )
            log_sigma_rX = mul( 0.5, plus(dot(self.W3_rX, h_rX), self.b3_rX), 'log_sigma_rX')

            self.tau = mu_rX.T

            # Diagonal optimisation of Tau
            # NOTE(review): Tau holds the reshaped log_sigma_rX values and
            # logDetTau is their plain sum - confirm the intended scaling.
            self.Tau_isDiagonal = True
            self.Tau = T.reshape(log_sigma_rX, [self.B * self.R, 1])
            self.logDetTau = T.sum(log_sigma_rX)
            self.Tau.name = 'Tau'
            self.logDetTau.name = 'logDetTau'

            self.rX_vars = [self.W1_rX, self.W2_rX, self.W3_rX, self.b1_rX, self.b2_rX, self.b3_rX]

        elif self.encoderType_rX == 'Kernel':

            # NOTE(review): tau is a shared matrix here but is not listed in
            # rX_vars below - confirm that it is meant to stay fixed.
            self.tau = sharedZeroMatrix(self.B, self.R, 'tau')

            # Tau_r [BxB] = kernel( [[BxQ]^T,[BxP]^T].T )
            Tau_r = kfactory.kernel(T.concatenate((self.z.T, self.y_miniBatch.T)).T, None, self.log_omega, 'Tau_r')
            (cTau_r, iTau_r, logDetTau_r) = cholInvLogDet(Tau_r, self.B, self.jitter)

            # The full Tau is never formed; only the Kronecker products of
            # its Cholesky factor and inverse with the [RxR] identity are kept.
            self.cTau = slinalg.kron(cTau_r, T.eye(self.R))
            self.iTau = slinalg.kron(iTau_r, T.eye(self.R))

            self.logDetTau = logDetTau_r * self.R
            self.tau.name  = 'tau'
            self.cTau.name = 'cTau'
            self.iTau.name = 'iTau'
            self.logDetTau.name = 'logDetTau'

            self.Tau_isDiagonal = False
            self.rX_vars = [self.log_omega]

        else:
            raise RuntimeError('Unrecognised encoding for r(X|z)')

        # Gradient variables: the shared variables collected in the *_vars
        # lists above, preceded by Xu only when Xu_optimise is set.
        if self.Xu_optimise:
            self.gradientVariables = [self.Xu]
        else:
            self.gradientVariables = []

        self.gradientVariables.extend(self.qu_vars)
        self.gradientVariables.extend(self.qz_vars)
        self.gradientVariables.extend(self.qX_vars)
        self.gradientVariables.extend(self.rX_vars)

        self.lowerBounds = []

        # Diagnostic functions: myCond applied to Kappa, Kuu, C and Upsilon,
        # each compiled with no_default_updates=True.
        self.condKappa = myCond()(self.Kappa)
        self.condKappa.name = 'condKappa'
        self.Kappa_conditionNumber = th.function([], self.condKappa, no_default_updates=True)

        self.condKuu = myCond()(self.Kuu)
        self.condKuu.name = 'condKuu'
        self.Kuu_conditionNumber = th.function([], self.condKuu, no_default_updates=True)

        self.condC = myCond()(self.C)
        self.condC.name = 'condC'
        self.C_conditionNumber = th.function([], self.condC, no_default_updates=True)

        self.condUpsilon = myCond()(self.Upsilon)
        self.condUpsilon.name = 'condUpsilon'
        self.Upsilon_conditionNumber = th.function([], self.condUpsilon, no_default_updates=True)

        # Compiled accessor for the current latent co-ordinates Xz
        self.Xz_get_value = th.function([], self.Xz, no_default_updates=True)