Example #1
    def __init__(self,
                 kernel,
                 inducing_variables,
                 mean_function,
                 white=False,
                 **kwargs):
        super().__init__(**kwargs)

        self.inducing_points = inducing_variables

        self.num_inducing = inducing_variables.shape[0]
        m = inducing_variables.shape[1]

        # Initialise q_mu to zeros.
        q_mu = np.zeros((self.num_inducing, 1))
        self.q_mu = Parameter(q_mu, dtype=default_float())

        # Initialise q_sqrt to near deterministic. Store as lower triangular matrix L.
        q_sqrt = 1e-4 * np.eye(self.num_inducing, dtype=default_float())
        self.q_sqrt = Parameter(q_sqrt, transform=triangular())

        self.kernel = kernel
        self.mean_function = mean_function
        self.white = white

        # Initialise to prior (Ku) + jitter.
        if not self.white:
            Ku = self.kernel(self.inducing_points)
            Ku += default_jitter() * tf.eye(self.num_inducing, dtype=Ku.dtype)
            Lu = tf.linalg.cholesky(Ku)
            q_sqrt = Lu
            self.q_sqrt = Parameter(q_sqrt, transform=triangular())
Example #2
    def _init_variational_parameters(self, num_inducing, q_mu, q_sqrt, q_diag):
        """
        Constructs the mean and cholesky of the covariance of the variational Gaussian posterior.
        If a user passes values for `q_mu` and `q_sqrt`, the routine checks that they have consistent
        and correct shapes. If a user does not specify any values for `q_mu` and `q_sqrt`, the routine
        initializes them; their shapes depend on `num_inducing` and `q_diag`.

        Note: most often the comments refer to the number of observations (=output dimensions) with P,
        number of latent GPs with L, and number of inducing points M. Typically P equals L,
        but when certain multioutput kernels are used, this can change.

        Parameters
        ----------
        :param num_inducing: int
            Number of inducing variables, typically referred to as M.
        :param q_mu: np.array or None
            Mean of the variational Gaussian posterior. If None the function will initialise
            the mean with zeros. If not None, the shape of `q_mu` is checked.
        :param q_sqrt: np.array or None
            Cholesky of the covariance of the variational Gaussian posterior.
            If None the function will initialise `q_sqrt` with the identity matrix.
            If not None, the shape of `q_sqrt` is checked, depending on `q_diag`.
        :param q_diag: bool
            Used to check if `q_mu` and `q_sqrt` have the correct shape or to
            construct them with the correct shape. If `q_diag` is true,
            `q_sqrt` is two dimensional and only holds the square root of the
            covariance diagonal elements. If False, `q_sqrt` is three dimensional.
        """
        q_mu = np.zeros(
            (num_inducing, self.num_latent_gps)) if q_mu is None else q_mu
        self.q_mu = Parameter(q_mu, dtype=default_float())  # [M, P]

        if q_sqrt is None:
            if self.q_diag:
                ones = np.ones((num_inducing, self.num_latent_gps),
                               dtype=default_float())
                self.q_sqrt = Parameter(ones, transform=positive())  # [M, P]
            else:
                q_sqrt = [
                    np.eye(num_inducing, dtype=default_float())
                    for _ in range(self.num_latent_gps)
                ]
                q_sqrt = np.array(q_sqrt)
                self.q_sqrt = Parameter(q_sqrt,
                                        transform=triangular())  # [P, M, M]
        else:
            if q_diag:
                assert q_sqrt.ndim == 2
                self.num_latent_gps = q_sqrt.shape[1]
                self.q_sqrt = Parameter(q_sqrt,
                                        transform=positive())  # [M, L|P]
            else:
                assert q_sqrt.ndim == 3
                self.num_latent_gps = q_sqrt.shape[0]
                num_inducing = q_sqrt.shape[1]
                self.q_sqrt = Parameter(q_sqrt,
                                        transform=triangular())  # [L|P, M, M]
Example #3
    def __init__(self,
                 kern,
                 Z,
                 num_outputs,
                 mean_function,
                 white=False,
                 input_prop_dim=None,
                 **kwargs):
        """
        A sparse variational GP layer in whitened representation. This layer holds the kernel,
        variational parameters, inducing points and mean function.

        The underlying model at inputs X is
        f = L v + mean_function(X), where v ~ N(0, I) and L L^T = kern.K(X)

        The variational distribution over the inducing points is
        q(v) = N(q_mu, q_sqrt q_sqrt^T)

        The layer holds D_out independent GPs with the same kernel and inducing points.

        :param kern: The kernel for the layer (input_dim = D_in)
        :param Z: Inducing points (M, D_in)
        :param num_outputs: The number of GP outputs (q_mu is shape (M, num_outputs))
        :param mean_function: The mean function
        :return:
        """
        super().__init__(input_prop_dim=input_prop_dim, **kwargs)
        self.num_inducing = Z.shape[0]

        # Mean of the variational distribution over the inducing outputs
        q_mu = np.zeros((self.num_inducing, num_outputs))
        self.q_mu = Parameter(q_mu, name="q_mu")
        # Square root of the covariance of the variational distribution over the inducing outputs
        q_sqrt = np.tile(
            np.eye(self.num_inducing)[None, :, :], [num_outputs, 1, 1])
        self.q_sqrt = Parameter(q_sqrt, transform=triangular(), name="q_sqrt")

        self.feature = InducingPoints(Z)
        self.kern = kern
        self.mean_function = mean_function

        self.num_outputs = num_outputs
        self.white = white

        if not self.white:  # initialize to prior
            Ku = self.kern.K(Z)
            Lu = np.linalg.cholesky(Ku + np.eye(Z.shape[0]) *
                                    gpflow.default_jitter())
            self.q_sqrt = Parameter(np.tile(Lu[None, :, :],
                                            [num_outputs, 1, 1]),
                                    transform=triangular(),
                                    name="q_sqrt")

        self.Ku, self.Lu, self.Ku_tiled, self.Lu_tiled = None, None, None, None
        self.needs_build_cholesky = True
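The whitened representation described in the docstring can be written out directly: a sample of f at a set of inputs is obtained by colouring a standard-normal v with the Cholesky factor of the kernel matrix. A minimal NumPy sketch of that identity (toy kernel and illustrative sizes, not code from the project):

import numpy as np

rng = np.random.default_rng(0)
M = 5                                    # number of inducing points (illustrative)
Z = rng.normal(size=(M, 1))              # inducing inputs

# Toy squared-exponential kernel matrix, with jitter for numerical stability.
K = np.exp(-0.5 * (Z - Z.T) ** 2) + 1e-6 * np.eye(M)
L = np.linalg.cholesky(K)                # K = L L^T

v = rng.standard_normal((M, 1))          # v ~ N(0, I), the whitened variable
f = L @ v                                # f = L v has covariance K; add mean_function(Z) for the full model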
Example #4
File: layers.py  Project: MattAshman/MLMI4
    def __init__(self,
                 kernel,
                 inducing_variables,
                 num_outputs,
                 mean_function,
                 input_prop_dim=None,
                 white=False,
                 **kwargs):
        super().__init__(input_prop_dim, **kwargs)

        self.num_inducing = inducing_variables.shape[0]
        self.mean_function = mean_function
        self.num_outputs = num_outputs
        self.white = white

        self.kernels = []
        for i in range(self.num_outputs):
            self.kernels.append(copy.deepcopy(kernel))

        # Initialise q_mu to all zeros
        q_mu = np.zeros((self.num_inducing, num_outputs))
        self.q_mu = Parameter(q_mu, dtype=default_float())

        # Initialise q_sqrt to identity matrices (one per output)
        #q_sqrt = tf.tile(tf.expand_dims(tf.eye(self.num_inducing,
        #    dtype=default_float()), 0), (num_outputs, 1, 1))
        q_sqrt = [
            np.eye(self.num_inducing, dtype=default_float())
            for _ in range(num_outputs)
        ]
        q_sqrt = np.array(q_sqrt)
        # Store as lower triangular matrix L.
        self.q_sqrt = Parameter(q_sqrt, transform=triangular())

        # Initialise to prior (Ku) + jitter.
        if not self.white:
            Kus = [
                self.kernels[i].K(inducing_variables)
                for i in range(self.num_outputs)
            ]
            Lus = [
                np.linalg.cholesky(Kus[i] + np.eye(self.num_inducing) *
                                   default_jitter())
                for i in range(self.num_outputs)
            ]
            q_sqrt = Lus
            q_sqrt = np.array(q_sqrt)
            self.q_sqrt = Parameter(q_sqrt, transform=triangular())

        self.inducing_points = []
        for i in range(self.num_outputs):
            self.inducing_points.append(
                inducingpoint_wrapper(inducing_variables))
Example #5
    def __init__(self,
                 kernel,
                 inducing_variables,
                 q_mu_initial,
                 q_sqrt_initial,
                 mean_function,
                 white=False,
                 **kwargs):
        super().__init__(**kwargs)

        self.inducing_points = inducing_variables
        self.num_inducing = inducing_variables.shape[0]

        # Initialise q_mu to y^2_pi(i)
        q_mu = q_mu_initial[:, None]
        self.q_mu = Parameter(q_mu, dtype=default_float())

        # Initialise q_sqrt to near deterministic. Store as lower triangular matrix L.
        q_sqrt = 1e-4 * np.eye(self.num_inducing, dtype=default_float())
        #q_sqrt = np.diag(q_sqrt_initial)
        self.q_sqrt = Parameter(q_sqrt, transform=triangular())

        self.kernel = kernel
        self.mean_function = mean_function
        self.white = white
Example #6
    def init_variational_params(self, num_inducing):
        q_mu = np.zeros(
            (num_inducing, self.num_kernels, self.num_latent_gps))  # M x K x O
        self.q_mu = Parameter(q_mu, dtype=default_float())

        q_sqrt = []
        for _ in range(self.num_kernels):
            q_sqrt.append([
                np.eye(num_inducing, dtype=default_float())
                for _ in range(self.num_latent_gps)
            ])
        q_sqrt = np.array(q_sqrt)
        self.q_sqrt = Parameter(q_sqrt,
                                transform=triangular())  # K x O x M x M
Example #7
    def __init__(
        self,
        data: OutputData,
        Xp_mean: tf.Tensor,
        Xp_var: tf.Tensor,
        pi: tf.Tensor,
        kernel_K: List[Kernel],
        Zp: tf.Tensor,
        Xs_mean=None,
        Xs_var=None,
        kernel_s=None,
        Zs=None,
        Xs_prior_mean=None,
        Xs_prior_var=None,
        Xp_prior_mean=None,
        Xp_prior_var=None,
        pi_prior=None
    ):
        super().__init__(
            data=data,
            split_space=True, 
            Xp_mean=Xp_mean,
            Xp_var=Xp_var,
            pi=pi,
            kernel_K=kernel_K,
            Zp=Zp,
            Xs_mean=Xs_mean,
            Xs_var=Xs_var,
            kernel_s=kernel_s,
            Zs=Zs,
            Xs_prior_mean=Xs_prior_mean,
            Xs_prior_var=Xs_prior_var,
            Xp_prior_mean=Xp_prior_mean,
            Xp_prior_var=Xp_prior_var,
            pi_prior=pi_prior
        )
        # q(Us | Ms, Ss)
        q_mu = np.zeros((self.M, self.D))
        self.q_mu_s = Parameter(q_mu, dtype=default_float())  # [M, D]

        q_sqrt = [
            np.eye(self.M, dtype=default_float()) for _ in range(self.D)
        ]
        q_sqrt = np.array(q_sqrt)
        self.q_sqrt_s = Parameter(q_sqrt, transform=triangular())  # [D, M, M]
Example #8
def main(args):
    datasets = Datasets(data_path=args.data_path)

    # Prepare output files
    outname1 = '../tmp/' + args.dataset + '_' + str(args.num_layers) + '_'\
            + str(args.num_inducing) + '.nll'
    if not os.path.exists(os.path.dirname(outname1)):
        os.makedirs(os.path.dirname(outname1))
    outfile1 = open(outname1, 'w')
    outname2 = '../tmp/' + args.dataset + '_' + str(args.num_layers) + '_'\
            + str(args.num_inducing) + '.time'
    outfile2 = open(outname2, 'w')

    running_loss = 0
    running_time = 0
    for i in range(args.splits):
        print('Split: {}'.format(i))
        print('Getting dataset...')
        data = datasets.all_datasets[args.dataset].get_data(i)
        X, Y, Xs, Ys, Y_std = [
            data[_] for _ in ['X', 'Y', 'Xs', 'Ys', 'Y_std']
        ]
        Z = kmeans2(X, args.num_inducing, minit='points')[0]

        # set up batches
        batch_size = args.M if args.M < X.shape[0] else X.shape[0]
        train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).repeat()\
                .prefetch(X.shape[0]//2)\
                .shuffle(buffer_size=(X.shape[0]//2))\
                .batch(batch_size)

        print('Setting up DGP model...')
        kernels = []
        for l in range(args.num_layers):
            kernels.append(SquaredExponential() + White(variance=1e-5))

        dgp_model = DGP(X.shape[1],
                        kernels,
                        Gaussian(variance=0.05),
                        Z,
                        num_outputs=Y.shape[1],
                        num_samples=args.num_samples,
                        num_data=X.shape[0])

        # initialise inner layers almost deterministically
        for layer in dgp_model.layers[:-1]:
            layer.q_sqrt = Parameter(layer.q_sqrt.value() * 1e-5,
                                     transform=triangular())

        optimiser = tf.optimizers.Adam(args.learning_rate)

        def optimisation_step(model, X, Y):
            with tf.GradientTape() as tape:
                tape.watch(model.trainable_variables)
                obj = -model.elbo(X, Y, full_cov=False)
            # Compute gradients outside the tape context, then apply them.
            grad = tape.gradient(obj, model.trainable_variables)
            optimiser.apply_gradients(zip(grad, model.trainable_variables))

        def monitored_training_loop(model, train_dataset, logdir, iterations,
                                    logging_iter_freq):
            # TODO: use tensorboard to log trainables and performance
            tf_optimisation_step = tf.function(optimisation_step)
            batches = iter(train_dataset)

            for i in range(iterations):
                X, Y = next(batches)
                tf_optimisation_step(model, X, Y)

                iter_id = i + 1
                if iter_id % logging_iter_freq == 0:
                    tf.print(
                        f'Iteration {iter_id}: ELBO (batch) {model.elbo(X, Y)}')

        print('Training DGP model...')
        t0 = time.time()
        monitored_training_loop(dgp_model,
                                train_dataset,
                                logdir=args.log_dir,
                                iterations=args.iterations,
                                logging_iter_freq=args.logging_iter_freq)
        t1 = time.time()
        print('Time taken to train: {}'.format(t1 - t0))
        outfile2.write('Split {}: {}\n'.format(i + 1, t1 - t0))
        outfile2.flush()
        os.fsync(outfile2.fileno())
        running_time += t1 - t0

        m, v = dgp_model.predict_y(Xs, num_samples=args.test_samples)
        test_nll = np.mean(
            logsumexp(norm.logpdf(Ys * Y_std, m * Y_std, v**0.5 * Y_std),
                      0,
                      b=1 / float(args.test_samples)))
        print('Average test log likelihood: {}'.format(test_nll))
        outfile1.write('Split {}: {}\n'.format(i + 1, test_nll))
        outfile1.flush()
        os.fsync(outfile1.fileno())
        running_loss += test_nll

    outfile1.write('Average: {}\n'.format(running_loss / args.splits))
    outfile2.write('Average: {}\n'.format(running_time / args.splits))
    outfile1.close()
    outfile2.close()
Example #9
    def __init__(
        self,
        inducing_variable: gpflow.inducing_variables.InducingVariables,
        kernel: gpflow.kernels.Kernel,
        domain: np.ndarray,
        q_mu: np.ndarray,
        q_S: np.ndarray,
        *,
        beta0: float = 1e-6,
        num_observations: int = 1,
        num_events: Optional[int] = None,
    ):
        """
        D = number of dimensions
        M = size of inducing variables (number of inducing points)

        :param inducing_variable: inducing variables (here only implemented for a gpflow
            .inducing_variables.InducingPoints instance, with Z of shape M x D)
        :param kernel: the kernel (here only implemented for a gpflow.kernels
            .SquaredExponential instance)
        :param domain: lower and upper bounds of (hyper-rectangular) domain
            (D x 2)

        :param q_mu: initial mean vector of the variational distribution q(u)
            (length M)
        :param q_S: how to initialise the covariance matrix of the variational
            distribution q(u)  (M x M)

        :param beta0: a constant offset, corresponding to initial value of the
            prior mean of the GP (but trainable); should be sufficiently large
            so that the GP does not go negative...

        :param num_observations: number of observations of sets of events
            under the distribution

        :param num_events: total number of events, defaults to events.shape[0]
            (relevant when feeding in minibatches)
        """
        super().__init__(kernel, likelihood=None)  # custom likelihood

        # observation domain  (D x 2)
        self.domain = domain
        if domain.ndim != 2 or domain.shape[1] != 2:
            raise ValueError("domain must be of shape D x 2")

        self.num_observations = num_observations
        self.num_events = num_events

        if not (isinstance(kernel, gpflow.kernels.SquaredExponential)
                and isinstance(inducing_variable,
                               gpflow.inducing_variables.InducingPoints)):
            raise NotImplementedError(
                "This VBPP implementation can only handle real-space "
                "inducing points together with the SquaredExponential "
                "kernel.")
        self.kernel = kernel
        self.inducing_variable = inducing_variable

        self.beta0 = Parameter(beta0, transform=positive(),
                               name="beta0")  # constant mean offset

        # variational approximate Gaussian posterior q(u) = N(u; m, S)
        self.q_mu = Parameter(q_mu, name="q_mu")  # mean vector  (length M)

        # covariance:
        L = np.linalg.cholesky(
            q_S)  # S = L L^T, with L lower-triangular  (M x M)
        self.q_sqrt = Parameter(L, transform=triangular(), name="q_sqrt")

        self.psi_jitter = 0.0
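Note that the constructor takes the full covariance `q_S` of q(u) and stores its Cholesky factor as `q_sqrt`. A short sketch of preparing those inputs and of the relationship S = L L^T (illustrative values only):

import numpy as np

M = 4                            # number of inducing points (illustrative)
q_mu = np.zeros(M)               # mean vector of q(u), length M
q_S = np.eye(M)                  # initial covariance of q(u), must be positive definite

L = np.linalg.cholesky(q_S)      # this factor is what gets stored as q_sqrt
assert np.allclose(L @ L.T, q_S)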
Example #10
def test_triangular():
    assert isinstance(triangular(), tfp.bijectors.FillTriangular)
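The test above only asserts that `triangular()` returns a `tfp.bijectors.FillTriangular` instance. That bijector maps an unconstrained vector of length M(M+1)/2 to an M x M lower-triangular matrix and back, which is how the `q_sqrt` parameters in the earlier examples stay triangular under optimisation. A minimal sketch of the bijector itself (assuming tensorflow_probability is available):

import numpy as np
import tensorflow_probability as tfp

bij = tfp.bijectors.FillTriangular()

flat = np.arange(1.0, 7.0)       # length 6 = 3*(3+1)/2, so M = 3
L = bij.forward(flat)            # [3, 3] lower-triangular matrix
flat_again = bij.inverse(L)      # round trip back to the length-6 vector

The bijector also broadcasts over leading batch dimensions, which is why parameters of shape [P, M, M] can use the same transform.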
Example #11
def main(args):
    datasets = Datasets(data_path=args.data_path)

    # prepare output files
    outname1 = args.results_dir + args.dataset + '_' + str(args.num_layers) + '_'\
            + str(args.num_inducing) + '.rmse'
    if not os.path.exists(os.path.dirname(outname1)):
        os.makedirs(os.path.dirname(outname1))
    outfile1 = open(outname1, 'w')

    outname2 = args.results_dir + args.dataset + '_' + str(args.num_layers) + '_'\
            + str(args.num_inducing) + '.nll'
    outfile2 = open(outname2, 'w')

    outname3 = args.results_dir + args.dataset + '_' + str(args.num_layers) + '_'\
            + str(args.num_inducing) + '.time'
    outfile3 = open(outname3, 'w')

    # =========================================================================
    # CROSS-VALIDATION LOOP
    # =========================================================================
    running_err = 0
    running_loss = 0
    running_time = 0
    test_errs = np.zeros(args.splits)
    test_nlls = np.zeros(args.splits)
    test_times = np.zeros(args.splits)
    for i in range(args.splits):
        # =====================================================================
        # MODEL CONSTRUCTION
        # =====================================================================
        print('Split: {}'.format(i))
        print('Getting dataset...')
        # get dataset
        data = datasets.all_datasets[args.dataset].get_data(
            i, normalize=args.normalize_data)
        X, Y, Xs, Ys, Y_std = [
            data[_] for _ in ['X', 'Y', 'Xs', 'Ys', 'Y_std']
        ]

        # inducing points via k-means
        Z = kmeans2(X, args.num_inducing, minit='points')[0]

        # set up batches
        batch_size = args.M if args.M < X.shape[0] else X.shape[0]
        train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).repeat()\
            .prefetch(X.shape[0]//2)\
            .shuffle(buffer_size=(X.shape[0]//2))\
            .batch(batch_size)

        print('Setting up DGP model...')
        kernels = []
        dims = []

        # hidden_dim = min(args.max_dim, X.shape[1])
        hidden_dim = X.shape[1] if X.shape[1] < args.max_dim else args.max_dim
        for l in range(args.num_layers):
            if l == 0:
                dim = X.shape[1]
                dims.append(dim)
            else:
                dim = hidden_dim
                dims.append(dim)

            if args.ard:
                # SE kernel with lengthscale per dimension
                kernels.append(
                    SquaredExponential(lengthscale=[1.] * dim) +
                    White(variance=1e-5))
            else:
                # SE kernel with single lengthscale
                kernels.append(
                    SquaredExponential(lengthscale=1.) + White(variance=1e-5))

        # output dim
        dims.append(Y.shape[1])

        dgp_model = DGP(X,
                        Y,
                        Z,
                        dims,
                        kernels,
                        Gaussian(variance=0.05),
                        num_samples=args.num_samples,
                        num_data=X.shape[0])

        # initialise inner layers almost deterministically
        for layer in dgp_model.layers[:-1]:
            layer.q_sqrt = Parameter(layer.q_sqrt.value() * 1e-5,
                                     transform=triangular())

        # =====================================================================
        # TRAINING
        # =====================================================================
        optimiser = tf.optimizers.Adam(args.learning_rate)

        print('Training DGP model...')
        t0 = time.time()
        # training loop
        monitored_training_loop(dgp_model,
                                train_dataset,
                                optimiser=optimiser,
                                logdir=args.log_dir,
                                iterations=args.iterations,
                                logging_iter_freq=args.logging_iter_freq)
        t1 = time.time()

        # =====================================================================
        # TESTING
        # =====================================================================
        test_times[i] = t1 - t0
        print('Time taken to train: {}'.format(t1 - t0))
        outfile3.write('Split {}: {}\n'.format(i + 1, t1 - t0))
        outfile3.flush()
        os.fsync(outfile3.fileno())
        running_time += t1 - t0

        # minibatch test predictions
        means, vars = [], []
        test_batch_size = args.test_batch_size
        if len(Xs) > test_batch_size:
            for mb in range(-(-len(Xs) // test_batch_size)):
                m, v = dgp_model.predict_y(Xs[mb * test_batch_size:(mb + 1) *
                                              test_batch_size, :],
                                           num_samples=args.test_samples)
                means.append(m)
                vars.append(v)
        else:
            m, v = dgp_model.predict_y(Xs, num_samples=args.test_samples)
            means.append(m)
            vars.append(v)

        mean_SND = np.concatenate(means, 1)  # [S, N, D]
        var_SND = np.concatenate(vars, 1)  # [S, N, D]
        mean_ND = np.mean(mean_SND, 0)  # [N, D]

        # rmse
        test_err = np.mean(Y_std * np.mean((Ys - mean_ND)**2.0)**0.5)
        test_errs[i] = test_err
        print('Average RMSE: {}'.format(test_err))
        outfile1.write('Split {}: {}\n'.format(i + 1, test_err))
        outfile1.flush()
        os.fsync(outfile1.fileno())
        running_err += test_err

        # nll
        test_nll = np.mean(
            logsumexp(norm.logpdf(Ys * Y_std, mean_SND * Y_std,
                                  var_SND**0.5 * Y_std),
                      0,
                      b=1 / float(args.test_samples)))
        test_nlls[i] = test_nll
        print('Average test log likelihood: {}'.format(test_nll))
        outfile2.write('Split {}: {}\n'.format(i + 1, test_nll))
        outfile2.flush()
        os.fsync(outfile2.fileno())
        running_loss += test_nll

    outfile1.write('Average: {}\n'.format(running_err / args.splits))
    outfile1.write('Standard deviation: {}\n'.format(np.std(test_errs)))
    outfile2.write('Average: {}\n'.format(running_loss / args.splits))
    outfile2.write('Standard deviation: {}\n'.format(np.std(test_nlls)))
    outfile3.write('Average: {}\n'.format(running_time / args.splits))
    outfile3.write('Standard deviation: {}\n'.format(np.std(test_times)))
    outfile1.close()
    outfile2.close()
    outfile3.close()
Example #12
    def __init__(
        self,
        data: OutputData,
        split_space: bool, 
        Xp_mean: tf.Tensor,
        Xp_var: tf.Tensor,
        pi: tf.Tensor,
        kernel_K: List[Kernel],
        Zp: tf.Tensor,
        Xs_mean=None,
        Xs_var=None,
        kernel_s=None,
        Zs=None,
        Xs_prior_mean=None,
        Xs_prior_var=None,
        Xp_prior_mean=None,
        Xp_prior_var=None,
        pi_prior=None
    ):
        """
        Initialise Bayesian GPLVM object. This method only works with a Gaussian likelihood.

        :param data: data matrix, size N (number of points) x D (dimensions)
        :param split_space: if True, use both shared and private spaces;
            if False, use only private spaces (note: to recover GPLVM, set split_space=False and let K=1)
        :param Xp_mean: mean latent positions in the private space [N, Qp] (Qp is the dimension of the private space)
        :param Xp_var: variance of the latent positions in the private space [N, Qp]
        :param pi: mixture responsibility of each category to each point [N, K] (K is the number of categories), i.e. q(c)
        :param kernel_K: private space kernel, one for each category
        :param Zp: inducing inputs of the private space [M, Qp]
        :param Xs_mean: mean latent positions in the shared space [N, Qs] (Qs is the dimension of the shared space). i.e. mus in q(Xs) ~ N(Xs | mus, Ss)
        :param Xs_var: variance of latent positions in shared space [N, Qs], i.e. Ss, assumed diagonal
        :param kernel_s: shared space kernel 
        :param Zs: inducing inputs of the shared space [M, Qs] (M is the number of inducing points)
        :param Xs_prior_mean: prior mean used in KL term of bound, [N, Qs]. By default 0. mean in p(Xs)
        :param Xs_prior_var: prior variance used in KL term of bound, [N, Qs]. By default 1. variance in p(Xs)
        :param Xp_prior_mean: prior mean used in KL term of bound, [N, Qp]. By default 0. mean in p(Xp)
        :param Xp_prior_var: prior variance used in KL term of bound, [N, Qp]. By default 1. variance in p(Xp)
        :param pi_prior: prior mixture weights used in KL term of the bound, [N, K]. By default uniform. p(c)        
        """

        # If you don't want a shared space, set split_space=False --> you get a mixture of GPLVMs.
        # If, in addition, K = 1 and `kernel_K` contains a single kernel --> you recover the original GPLVM.

        # TODO: think about how to do this with minibatch
        # it's awkward since w/ minibatch the model usually doesn't store the data internally
        # but for gplvm, you need to keep the q(xn) for all the n's
        # so you need to know which ones to update for each minibatch, probably can be solved but not pretty
        # using inference network / back constraints will solve this, since we will be keeping a global set of parameters
        # rather than a set for each q(xn)
        self.N, self.D = data.shape
        self.Qp = Xp_mean.shape[1]
        self.K = pi.shape[1]
        self.split_space = split_space

        assert Xp_var.ndim == 2
        assert len(kernel_K) == self.K
        assert np.all(Xp_mean.shape == Xp_var.shape)
        assert Xp_mean.shape[0] == self.N, "Xp_mean and Y must be of same size"
        assert pi.shape[0] == self.N, "pi and Y must be of the same size"

        super().__init__()
        self.likelihood = likelihoods.Gaussian()
        self.kernel_K = kernel_K
        self.data = data_input_to_tensor(data)
        # the covariance of q(X) as a [N, Q] matrix, the assumption is that Sn's are diagonal
        # i.e. the latent dimensions are uncorrelated
        # otherwise would require a [N, Q, Q] matrix
        self.Xp_mean = Parameter(Xp_mean)
        self.Xp_var = Parameter(Xp_var, transform=positive())
        self.pi = Parameter(pi, transform=tfp.bijectors.SoftmaxCentered())
        self.Zp = inducingpoint_wrapper(Zp)
        self.M = len(self.Zp)

        # initialize the variational parameters for q(U), same way as in SVGP
        # q_mu: List[K], mean of the inducing variables U [M, D], i.e m in q(U) ~ N(U | m, S), 
        #   initialized as zeros
        # q_sqrt: List[K], cholesky of the covariance matrix of the inducing variables [D, M, M]
        #   q_diag is false because natural gradient only works for full covariance
        #   initialized as all identities
        # we need K sets of q(Uk), each approximating fs+fk
        self.q_mu = []
        self.q_sqrt = []
        for k in range(self.K):
            q_mu = np.zeros((self.M, self.D))
            q_mu = Parameter(q_mu, dtype=default_float())  # [M, D]
            self.q_mu.append(q_mu)

            q_sqrt = [
                np.eye(self.M, dtype=default_float()) for _ in range(self.D)
            ]
            q_sqrt = np.array(q_sqrt)
            q_sqrt = Parameter(q_sqrt, transform=triangular())  # [D, M, M]
            self.q_sqrt.append(q_sqrt)

        # deal with parameters for the prior 
        if Xp_prior_mean is None:
            Xp_prior_mean = tf.zeros((self.N, self.Qp), dtype=default_float())
        if Xp_prior_var is None:
            Xp_prior_var = tf.ones((self.N, self.Qp), dtype=default_float())
        if pi_prior is None:
            pi_prior = tf.ones((self.N, self.K), dtype=default_float()) * 1/self.K

        self.Xp_prior_mean = tf.convert_to_tensor(np.atleast_1d(Xp_prior_mean), dtype=default_float())
        self.Xp_prior_var = tf.convert_to_tensor(np.atleast_1d(Xp_prior_var), dtype=default_float()) 
        self.pi_prior = tf.convert_to_tensor(np.atleast_1d(pi_prior), dtype=default_float()) 


        # if we have both shared space and private space, need to initialize the parameters for the shared space
        if split_space:
            assert Xs_mean is not None and Xs_var is not None and kernel_s is not None and Zs is not None, \
                'Xs_mean, Xs_var, kernel_s and Zs must be provided if `split_space=True`'
            assert Xs_var.ndim == 2 
            assert np.all(Xs_mean.shape == Xs_var.shape)
            assert Xs_mean.shape[0] == self.N, "Xs_mean and Y must be of same size"
            self.Qs = Xs_mean.shape[1]
            self.kernel_s = kernel_s
            self.Xs_mean = Parameter(Xs_mean)
            self.Xs_var = Parameter(Xs_var, transform=positive())
            self.Zs = inducingpoint_wrapper(Zs)

            if len(Zs) != len(Zp):
                raise ValueError(
                    '`Zs` and `Zp` should have the same length'
                )

            if Xs_prior_mean is None:
                Xs_prior_mean = tf.zeros((self.N, self.Qs), dtype=default_float())
            if Xs_prior_var is None:
                Xs_prior_var = tf.ones((self.N, self.Qs), dtype=default_float())
            self.Xs_prior_mean = tf.convert_to_tensor(np.atleast_1d(Xs_prior_mean), dtype=default_float())
            self.Xs_prior_var = tf.convert_to_tensor(np.atleast_1d(Xs_prior_var), dtype=default_float())

        self.Fq = tf.zeros((self.N, self.K), dtype=default_float())
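The `pi` responsibilities above are stored with a `tfp.bijectors.SoftmaxCentered` transform, which keeps every row of q(c) on the probability simplex. A small sketch of what that bijector does (illustrative only; N = 4 points, K = 3 categories):

import numpy as np
import tensorflow_probability as tfp

bij = tfp.bijectors.SoftmaxCentered()

unconstrained = np.zeros((4, 2))   # rows of length K-1 in unconstrained space
pi = bij.forward(unconstrained)    # shape [4, 3]; each row is non-negative and sums to 1

With all-zero unconstrained rows this gives the uniform responsibility 1/K per category, matching the default `pi_prior` used above.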