示例#1
0
    def predict(self, y, phi_gmm, encoder_layers, decoder_layers, seed=0):
        """Reconstruct ``y`` and assign each row to its most probable cluster.

        A single latent sample is drawn per data point (``nb_samples = 1``) and
        decoded.

        NOTE(review): ``Encoder``/``Decoder`` are constructed fresh on every
        call from the layer specs; confirm they load trained weights, otherwise
        the outputs come from untrained networks.

        Args:
            y: data to cluster and reconstruct
            phi_gmm: GMM parameters, passed through to ``e_step``
            encoder_layers: encoder NN architecture
            decoder_layers: decoder NN architecture
            seed: random seed, forwarded to both ``e_step`` and ``subsample_x``

        Returns:
            Tuple ``(y_mean, clusters)``: the first output of the decoder for
            the sampled latents, and ``argmax`` of ``log_r_nk`` over dim 1.
        """
        nb_samples = 1
        phi_enc_model = Encoder(layerspecs=encoder_layers)
        phi_enc = phi_enc_model.forward(y)

        # Forward the caller's seed (it used to be hard-coded to 0 here while
        # subsample_x below received the real one).
        x_k_samples, log_r_nk, _, _ = e_step(phi_enc,
                                             phi_gmm,
                                             nb_samples,
                                             seed=seed)
        # keep the first (and only) sample for each data point
        x_samples = subsample_x(x_k_samples, log_r_nk, seed)[:, 0, :]

        y_recon_model = Decoder(layerspecs=decoder_layers)
        y_mean, _ = y_recon_model.forward(x_samples)

        return (y_mean, torch.argmax(log_r_nk, dim=1))
示例#2
0
class PredictModel(nn.Module):
    """Edge-type and task-label predictor on top of a graph ``Encoder``.

    Node embeddings produced by ``self.encoder`` are scored in two ways:
    concatenated (worker, task) embedding pairs are mapped to ``num_rels``
    edge logits, and task-node embeddings are mapped to ``num_rels`` label
    logits.

    NOTE(review): ``dropout``, ``use_cuda``, ``reg_param`` and ``e_dim`` are
    accepted but unused here, and ``fc_trans_wkr`` / ``fc_trans_tsk`` are
    created but never called in this class.
    """

    def __init__(self,
                 model_type,
                 g,
                 num_nodes,
                 num_wkr,
                 num_tsk,
                 num_rels,
                 feat_dim,
                 mv_results,
                 num_heads,
                 dropout=0,
                 use_cuda=False,
                 reg_param=0,
                 e_dim=20,
                 feat_init=None):
        """
        Args:
            model_type: passed to ``Encoder``
            g: passed to ``Encoder`` (presumably the worker/task graph)
            num_nodes: total node count; task rows of the encoder output are
                ``num_wkr:num_nodes`` (see ``predict_label_score``)
            num_wkr: number of worker nodes
            num_tsk: number of task nodes
            num_rels: number of classes for both the edge and label heads
            feat_dim: node embedding size
            mv_results: stored on the instance only
            num_heads: passed to ``Encoder``
            feat_init: ``"same"`` or ``"rand"``; selects the input features
                used by ``get_loss``
        """
        super(PredictModel, self).__init__()
        self.num_nodes = num_nodes
        self.num_wkr = num_wkr
        self.num_tsk = num_tsk
        self.feat_dim = feat_dim
        self.num_rels = num_rels
        self.mv_results = mv_results
        self.g = g
        self.model_type = model_type
        self.num_heads = num_heads
        self.feat_init = feat_init
        self.fc_trans_wkr = nn.Linear(num_tsk, feat_dim, bias=True)
        self.fc_trans_tsk = nn.Linear(num_wkr, feat_dim, bias=True)

        self.fc_edge_pred = nn.Linear(feat_dim * 2, num_rels, bias=True)
        self.fc_label_pred = nn.Linear(feat_dim, num_rels, bias=True)

        self.encoder = Encoder(model_type=model_type,
                               g=g,
                               num_nodes=num_nodes,
                               num_wkr=num_wkr,
                               num_tsk=num_tsk,
                               feat_dim=feat_dim,
                               num_rels=num_rels,
                               num_heads=num_heads)

    def predict_edge_score(self, ndata_h, triplets):
        """Return edge logits for each (worker, relation, task) triplet.

        Args:
            ndata_h: node embeddings indexed by node id
            triplets: column 0 holds worker ids, column 2 task ids

        Returns:
            ``fc_edge_pred`` applied to the concatenated worker/task embeddings.
        """
        wkr = triplets[:, 0]
        tsk = triplets[:, 2]
        # squeeze(1) drops a singleton axis left after indexing — presumably a
        # single-head axis in ndata_h or an (E, 1) id column; TODO confirm.
        wkr_feature = ndata_h[wkr.long()].squeeze(1)
        tsk_feature = ndata_h[tsk.long()].squeeze(1)

        edge_predict_score = self.fc_edge_pred(
            torch.cat((wkr_feature, tsk_feature), dim=1))
        return edge_predict_score

    def predict_label_score(self, ndata_h):
        """Return label logits for the task nodes (rows num_wkr..num_nodes)."""
        tsk_node_feature = ndata_h[self.num_wkr:self.num_nodes]
        tsk_label_score = self.fc_label_pred(tsk_node_feature)
        return tsk_label_score

    def regularization_loss(self, embedding):
        """Return the mean squared value of ``embedding``."""
        return torch.mean(embedding.pow(2))

    def get_loss(self, same_feat, triplets, train_tsk_id, true_labels):
        """Compute the label and edge cross-entropy losses.

        Args:
            same_feat: input node features; used as-is for ``feat_init ==
                "same"``, only as a shape/dtype/device template for ``"rand"``
            triplets: (worker, relation, task) triplets; column 1 holds the
                relation id used as the edge target
            train_tsk_id: indices (into the task rows) of labelled tasks
            true_labels: ground-truth labels, indexed by ``train_tsk_id``

        Returns:
            ``(predict_label_loss, predict_edge_loss, loss_sum,
            label_predict_score)``. ``loss_sum`` is ``predict_label_loss``
            only; the edge loss is returned separately but not added.

        Raises:
            ValueError: if ``self.feat_init`` is not ``"same"`` or ``"rand"``
                (previously this fell through to an UnboundLocalError, e.g.
                with the default ``feat_init=None``).
        """
        if self.feat_init == "same":
            features = same_feat
        elif self.feat_init == "rand":
            features = torch.rand_like(same_feat)
        else:
            raise ValueError(
                "feat_init must be 'same' or 'rand', got {!r}".format(
                    self.feat_init))
        ndata_h = self.encoder.forward(features)

        loss = nn.CrossEntropyLoss()
        # reshape(-1) turns (E,) or (E, 1) relation ids into (E,) without
        # collapsing a single-triplet batch to a 0-d target like squeeze() did
        rel = triplets[:, 1].reshape(-1).long()
        edge_predict_score = self.predict_edge_score(ndata_h, triplets)
        predict_edge_loss = loss(edge_predict_score, rel)
        label_predict_score = self.predict_label_score(ndata_h)
        predict_label_loss = loss(label_predict_score[train_tsk_id].squeeze(1),
                                  true_labels[train_tsk_id])
        loss_sum = predict_label_loss
        return predict_label_loss, predict_edge_loss, loss_sum, label_predict_score
示例#3
0
class GMMSVAE(nn.Module):
    def __init__(self, opts, encoderlayers, decoderlayers, input_dim=784):
        """Build the encoder/decoder networks and the GMM prior/recognition params.

        Args:
            opts: options object; ``device``, ``nb_components``,
                ``nb_samples``, ``latent_dims``, ``batch_size`` and ``seed``
                are read from it
            encoderlayers: layer specification handed to ``Encoder``
            decoderlayers: layer specification handed to ``Decoder``; the
                second field of its last entry is kept as ``decoder_type``
            input_dim: size of one flattened input example (784 by default)
        """
        super(GMMSVAE, self).__init__()

        # hyper-parameters copied off the options object
        self.device = opts.device
        self.encoder_layers = encoderlayers
        self.decoder_layers = decoderlayers
        self.decoder_type = decoderlayers[-1][1]
        self.nb_components = opts.nb_components
        self.nb_samples = opts.nb_samples
        self.latent_dims = opts.latent_dims
        self.batch_size = opts.batch_size
        self.seed = opts.seed
        self.input_dim = input_dim

        # neural networks: encoder on the inputs, decoder on the latents
        # (registration order kept: encoder first, then decoder)
        self.x_given_y_phi_model = Encoder(self.encoder_layers, input_dim)
        self.y_reconstruction_model = Decoder(self.decoder_layers,
                                              self.latent_dims)

        # fixed prior and the (non-trainable) current GMM parameters
        prior, theta = self.init_mm(self.nb_components, self.latent_dims,
                                    self.seed, self.device)
        self.gmm_prior = prior
        self.theta = theta

        # trainable recognition-side GMM parameters derived from theta
        mu_k, L_k, pi_k = self.init_recognition_params(
            theta, self.nb_components, self.seed, self.device)
        self.train_mu_k = mu_k
        self.train_L_k = L_k
        self.train_pi_k = pi_k

        self.totiter = 0

    def init_mm(self,
                nb_components,
                latent_dims,
                seed=0,
                param_device='cuda',
                theta_as_variable=True):
        '''
        Args:
        Returns:
            theta: Contains hyperparameters [alpha, A, b, beta, v_hat] where A, b, beta, v_hat are
            parameters of the NIW prior and alpha is the dirichlet prior on the parameters
        '''

        # prior parameters area always constance so the don't take gradients for them
        theta_prior = self.init_mm_params(nb_components,
                                          latent_dims,
                                          alpha_scale=0.05 / nb_components,
                                          beta_scale=0.5,
                                          m_scale=0,
                                          C_scale=latent_dims + 0.5,
                                          v_init=latent_dims + 0.5,
                                          seed=0,
                                          as_variables=False,
                                          trainable=False,
                                          device=param_device)

        theta = self.init_mm_params(nb_components,
                                    latent_dims,
                                    alpha_scale=1.,
                                    beta_scale=1.,
                                    m_scale=5.,
                                    C_scale=2 * (latent_dims),
                                    v_init=latent_dims + 1.,
                                    seed=0,
                                    as_variables=theta_as_variable,
                                    trainable=False)

        return theta_prior, theta

    def init_mm_params(self,
                       nb_components,
                       latent_dims,
                       alpha_scale=.1,
                       beta_scale=1e-5,
                       v_init=10.,
                       m_scale=1.,
                       C_scale=10.,
                       seed=0,
                       as_variables=True,
                       trainable=False,
                       device='cuda'):
        """Create GMM parameters in natural form from scaled standard values.

        Per component k the standard parameters are:
            alpha_k = alpha_scale, beta_k = beta_scale,
            dof_k   = latent_dims + v_init,
            mean_k  = m_scale * Uniform(-1, 1)^latent_dims (global torch RNG),
            cov_k   = C_scale * I.
        They are converted with ``niw.standard_to_natural`` and
        ``dirichlet.standard_to_natural``.

        Args:
            nb_components: number of mixture components K
            latent_dims: latent dimensionality D
            seed: unused
            as_variables: when True every output is passed through
                ``init_tensor_gpu_grad(..., trainable=trainable, device=device)``;
                otherwise ``trainable`` and ``device`` are ignored

        Returns:
            Tuple ``(alpha, A, b, beta, v_hat)`` of natural parameters.
        """
        ones_k = torch.ones(nb_components)
        alpha_std = alpha_scale * ones_k  # (K,)
        beta_std = beta_scale * ones_k  # (K,)
        dof_std = torch.tensor([float(latent_dims + v_init)]).expand(
            nb_components)  # (K,)
        # the only random draw: uniform means in [-1, 1], scaled
        mean_std = m_scale * torch.empty(nb_components,
                                         latent_dims).uniform_(-1, 1)  # (K, D)
        cov_std = C_scale * torch.eye(latent_dims).expand(
            nb_components, -1, -1)  # (K, D, D)

        A, b, beta, v_hat = niw.standard_to_natural(beta_std, mean_std,
                                                    cov_std, dof_std)
        alpha = dirichlet.standard_to_natural(alpha_std)
        natural = (alpha, A, b, beta, v_hat)

        if not as_variables:
            return natural
        return tuple(
            init_tensor_gpu_grad(param, trainable=trainable, device=device)
            for param in natural)

    def init_recognition_params(self,
                                theta,
                                nb_components,
                                seed=0,
                                param_device='cuda'):
        '''
        Args:
            theta [is a tuple that contains the following parameters]:
                alpha - the weights of the mixtures
                A - natural parameter of NIW
                b - natural parameter of NIW
                beta - parameter of NIW, is also called kappa
                v_hat - egree of freedom parameter of NIW
            nb_components: number of mixture components
        '''
        # make parameters for PGM part of recognition network

        mu_k, L_k = self.make_loc_scale_variables(theta, param_device)
        pi_k = torch.randn((self.nb_components, )).to(self.device)

        return mu_k, L_k, nn.Parameter(F.log_softmax(pi_k, dim=0))

    def make_loc_scale_variables(self, theta, param_device):
        '''
        Build trainable location/scale parameters from the NIW part of ``theta``.

        The NIW natural parameters ``theta[1:5]`` are converted to standard
        parameters, their expected mean and covariance are computed with
        ``niw.expected_values``, and the covariance is Cholesky-factorised.
        Clones are passed to the conversion so it cannot modify ``theta``
        in place.

        Args:
            theta: tuple (alpha, A, b, beta, v_hat):
                alpha - Dirichlet natural parameter (not used here)
                A - natural parameter of NIW
                b - natural parameter of NIW
                beta - parameter of NIW, also called kappa
                v_hat - degrees-of-freedom parameter of NIW
            param_device: device passed to ``init_tensor_gpu_grad``

        Returns:
            (mu_k, L_k): expected means and lower Cholesky factors of the
            expected covariances, both created with ``trainable=True``.
        '''
        theta_copied = niw.natural_to_standard(theta[1].clone(),
                                               theta[2].clone(),
                                               theta[3].clone(),
                                               theta[4].clone())

        mu_k_init, sigma_k = niw.expected_values(theta_copied)
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor
        L_k_init = torch.linalg.cholesky(sigma_k)

        mu_k = init_tensor_gpu_grad(mu_k_init,
                                    trainable=True,
                                    device=param_device)
        L_k = init_tensor_gpu_grad(L_k_init,
                                   trainable=True,
                                   device=param_device)

        return mu_k, L_k

    #def forward(self, y, phi_gmm, encoder_layers, decoder_layers, nb_samples=10, stddev_init_nn=0.01, seed=0):
    def forward(self, y):

        # Assume currently MINST data set, where first index is data, second is labels, and data is sorted as Size x Image_Row x Image_Col
        # assert list(y.shape[-2:]) == [28, 28], "The INPUT is not MNIST"

        # Use VAE encoder
        # x_given_y_phi = self.x_given_y_phi_model.forward(y.view(-1, 784).to(self.device))
        x_given_y_phi = self.x_given_y_phi_model.forward(y)
        #print ("Finished Encoder Forward pass at iteration {}".format(self.totiter))

        # execute E-step (update/sample local variables)
        x_k_samples, log_z_given_y_phi, phi_tilde, w_eta_12 = self.e_step(
            x_given_y_phi, (self.train_mu_k, self.train_L_k, self.train_pi_k),
            self.nb_samples,
            seed=0)
        #print ("Finished E-step Forward pass at iteration: {}".format(self.totiter))
        # compute reconstruction

        y_reconstruction = self.y_reconstruction_model.forward(x_k_samples)
        #print ("Finished Decoder Forward pass at iteration: {}".format(self.totiter))
        #temp = torch.tensor(0,dtype=torch.int64)
        x_samples = self.subsample_x(x_k_samples, log_z_given_y_phi,
                                     seed=0)[:, 0, :]

        return y_reconstruction, x_given_y_phi, x_k_samples, x_samples, log_z_given_y_phi, (
            self.train_mu_k, self.train_L_k, self.train_pi_k), phi_tilde

    def e_step(self, phi_enc, phi_gmm, nb_samples, seed=0):
        """

        Args:
            phi_enc: encoded data; Base Measure Natural Parameters [In this case mean and variance for Gaussian]
            phi_gmm: parameters of the Graphical model, mu, variance, and the cluster weight
            nb_samples: number of ties to sample from q(x|z,y)
            seed: random seed

        Returns:

        """

        # Natural Parameter Vector of Encoder
        # [see http://www.robots.ox.ac.uk/~cvrg/michaelmas2004/VariationalInferenceAndVMP.pdf slide 31]
        eta1_phi1, eta2_phi1_diag = phi_enc
        # diagonalize the percision/variance [shapes goes from (2,4) to (2,4,4)]
        eta2_phi1 = torch.diag_embed(eta2_phi1_diag)

        #unpack cluster weight and natural parameters
        eta1_phi2, eta2_phi2, pi_phi2 = self.unpack_recognition_gmm(phi_gmm)

        # compute log q(z|y, phi)
        log_z_given_y_phi, dbg = self.compute_log_z_given_y(
            eta1_phi1, eta2_phi1, eta1_phi2, eta2_phi2, pi_phi2)

        # compute parameters phi_tilde -- equations 20-24 in Wu Lin, Emtiyaz Khan Structured Inference Networks Paper
        # eta2_phi_tilde = eta2_phi1 + eta2_phi2
        # eta1_phi_tilde = inv(eta2_phi_tilde) * (eta1_phi1 + eta1_phi2)
        # eta1_phi_tilde.shape = (N, K, D, 1); eta2_phi_tilde.shape = (N, K, D, D)
        eta2_phi_tilde = eta2_phi1.unsqueeze(1) + eta2_phi2.unsqueeze(0)
        eta1_phi_tilde = (eta1_phi1.unsqueeze(1) +
                          eta1_phi2.unsqueeze(0)).unsqueeze(
                              -1)  # without inv(eta2_phi_tilde)

        x_k_samples = self.sample_x_per_comp(eta1_phi_tilde,
                                             eta2_phi_tilde,
                                             nb_samples,
                                             seed=0)

        return x_k_samples, log_z_given_y_phi, (eta1_phi_tilde,
                                                eta2_phi_tilde), dbg

    def sample_x_per_comp(self, eta1, eta2, nb_samples, seed=0):
        """
        Args:
            eta1: 1st Gaussian natural parameter, shape = N, K, L, 1
            eta2: 2nd Gaussian natural parameter, shape = N, K, L, L
            nb_samples: nb of samples to generate for each of the K components
            seed: random seed

        Returns:
            x ~ N(x|eta1[k], eta2[k]), nb_samples times for each of the K components.
        """

        inv_sigma = -2. * eta2  # For reason see e_step calculation of eta1_phi_tilde
        N, K, _, D = eta2.shape

        # cholesky decomposition and adding noise (raw_noise is of dimension (DxB), where B is the size of MC samples)
        # Note cholesky decomposition that the lower triangle can be interperted as the square root of the matrix
        L = torch.cholesky(inv_sigma)  # sigma = sqrt(variance)
        #sample_shape = (N.int(), K.int(), D.int(), nb_samples)
        sample_shape = (self.batch_size, self.nb_components, self.latent_dims,
                        self.nb_samples)
        raw_noise = torch.randn(sample_shape).cuda()
        noise = L.transpose(dim0=3, dim1=2).inverse() @ raw_noise

        # reparam-trick-sampling: x_samps = mu_tilde + noise: shape = N, K, S, D (permute = N-dim transpose)
        x_k_samps = (inv_sigma.inverse() @ eta1 + noise).permute(0, 1, 3, 2)

        return x_k_samps

    def subsample_x(self, x_k_samples, log_q_z_given_y, seed=0):
        """
        Given S samples for each of the K components for N datapoints (x_k_samples) and q(z_n=k|y), subsample S samples for
        each data point
        Args:
            x_k_samples: sample matrix of shape (N, K, S, L) 
            log_q_z_given_y: probability q(z_n=k|y_n, phi) [Shape: N x K]
            seed: random seed
        Returns:
            x_samples: a sample matrix of shape (N, S, L)
        """

        N, K, S, L = x_k_samples.shape

        # prepare indices for N and S dimension
        n_idx = torch.arange(start=0, end=N).unsqueeze(1).repeat(
            1, S
        )  # S samples for each observation N, n_idx[0] = len([0,0,0,...,0]) = S
        s_idx = torch.arange(start=0, end=S).unsqueeze(0).repeat(
            N, 1
        )  # N Each observation has S samples, s_idx[0] = [0,1,2,...,S] -- N Times

        tempvar = log_q_z_given_y.detach().cpu()
        temp = tempvar.sum(dim=1)
        if (temp == 0).nonzero().nelement() != 0:
            print("TEST zamps")

        # Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values
        # so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
        # Suspect: S, Solution: change S to an int
        m = torch.distributions.Categorical(logits=tempvar)
        z_samps = torch.transpose(
            m.sample([self.nb_samples]), dim0=1,
            dim1=0)  # output of sampling is Sample Shape x Batch Size

        # Make sure all indexes are ints
        #z_samps = z_samps.to(torch.int64)

        # tensor of shape (N, S, 3), containing indices of all chosen samples
        # choices = torch.cat((n_idx.unsqueeze(2),z_samps.unsqueeze(2),s_idx.unsqueeze(2)),dim=2) --- DON'T NEED TO DO IN PYTORCH

        # select the chosen samples from x_k_samples, choices are the indices needed to extract from x_k_samples
        # So to paraphrase again, we have K components (from the GMM model) and S samples of each component, where each sample represents the parameters
        # of the latent dimensions, and what we want is S samples for each unique observation in the batch (N) such that the resulting matrix has
        # S samples of N observations

        # For example if we have a minibatch of 64, N = 64, GMM clusters of 10, K = 10, 10 Samples for every cluster, S = 10, and 3 Latent Dims from NN, then L = 6
        # So we have x_k_samples = [64, 10, 10, 6]  then n_idx represents getting S samples for observation, and s_idx represents which sample to get from the K-the component
        # the Kth-Component is chosen from z_samps, such that z_samps will have a length of S, so if z_samps = [9, 9, 9, 2, 9, 9, 9, 9, 8, 2], we will have
        # 7 samples from the 9th component, 2 samples from the 2nd component, and 1 sample from the 8th component

        # Replaced tf.gather_nd in tensorflow (from original code) with pytorch's advance indexing

        return x_k_samples[[n_idx, z_samps, s_idx]]

    def unpack_recognition_gmm(self, phi_gmm):
        """
        
        Args:
            phi_gmm: Contains the parameters of the graphical model, specifically, natural parameters of mean and precision
            and the cluster weight
        
        Returns:
            Returns a tuple with the natural parameter of the mean and precision (1/variance) and the cluster weight
        
        """

        eta1, L_k_raw, pi_k_raw = phi_gmm
        #temp = pi_k_raw.cpu().numpy()
        #np.savetxt('test_1_{}.txt'.format(self.totiter),temp,fmt='%1.4e',delimiter=',')

        # Computer Precision - the inverted Variance (1/sigma^2)
        # Make sure L_k_raw is a valid Cholesky decomposition A = LL*, where L is lower triangle
        # L* is conjugate tranpose of L
        L_k = torch.tril(
            L_k_raw)  # Returns batch of lower triangular part of matrix

        # Get diagonals of lower triangular (Note for batch inputs need to do :To take a batch diagonal, pass in dim1=-2, dim2=-1)
        # see https://pytorch.org/docs/stable/torch.html#torch.diagonal
        diag_L_k = torch.diagonal(L_k, dim1=-2, dim2=-1)
        softplus_L_k = F.softplus(
            diag_L_k
        )  # Softplus function to make sure everything is positive-definite

        # Need to set diagonal of original Variance matrix to Softplus values, so use mask
        # see: https://stackoverflow.com/questions/49429147/replace-diagonal-elements-with-vector-in-pytorch/49431180#49431180
        mask = torch.diag_embed(torch.ones_like(softplus_L_k))
        L_k = torch.diag_embed(softplus_L_k) + (
            1. - mask
        ) * L_k  # * is overloaded in pytorch for elemente wise multiplication

        # Compute Precision [Note @ = matmul]
        the_Precision = L_k @ torch.transpose(
            L_k, 2, 1
        )  #dim's 1 and 2 are the cluster parameters, dim 0 are the actual clusters

        # Compute natural parameter of precision
        eta2 = -0.5 * the_Precision

        # make sure that log_pi_k are valid mixture coefficients, softmax normalizes pi_k such that torch.exp(sum(pi_k))=1
        pi_k = F.log_softmax(pi_k_raw, dim=0)

        return (eta1, eta2, pi_k)

    def compute_log_z_given_y(self, eta1_phi1, eta2_phi1, eta1_phi2, eta2_phi2,
                              pi_phi2):
        """
        Compute log q(z|y, phi) for every data point and mixture component.

        Args:
            eta1_phi1: encoder natural parameter (precision @ mean); shape (N, L)
            eta2_phi1: encoder natural parameter (-0.5 * precision); shape (N, L, L)
            eta1_phi2: recognition-GMM natural parameter; shape (K, L)
            eta2_phi2: recognition-GMM natural parameter; shape (K, L, L)
            pi_phi2: log mixture weights; shape (K,)
            where N = batch size, K = number of clusters, L = latent dimensionality

        Returns:
            Tuple ``(log_q_z, (w_eta1, w_eta2))``: the result of
            ``gaussian.log_probability_nat`` and the natural parameters passed
            to it, of shapes (N, K, L) and (N, K, L, L).
        """
        # Batch size is read from the inputs rather than self.batch_size so a
        # smaller final minibatch does not break the expand() calls below.
        N = eta2_phi1.shape[0]

        # Natural parameters of a Gaussian -- see
        # http://www.robots.ox.ac.uk/~cvrg/michaelmas2004/VariationalInferenceAndVMP.pdf (slide 31):
        #   eta2 = -0.5 * precision, eta1 = precision @ mean, precision = inv(variance)
        #
        # z ~ N(mu_phi1 | mu_phi2, sigma_phi1 + sigma_phi2) [Lin, Khan, VMP + SVAE pg. 6],
        # so the precision is inv(sigma_phi1 + sigma_phi2). In natural form:
        #   eta2_1 @ inv(eta2_1 + eta2_2) @ eta2_2  corresponds to  inv(sigma_phi1 + sigma_phi2)

        # (N, 1, L, L) + (1, K, L, L) -> (N, K, L, L)
        eta2_phi_tilde = eta2_phi1.unsqueeze(1) + eta2_phi2.unsqueeze(0)
        # computed once and reused for both the precision and the mean terms
        inv_eta2_phi_tilde = eta2_phi_tilde.inverse()

        # inv(eta2_1 + eta2_2) @ eta2_2; shape (N, K, L, L)
        inv_eta2_eta2_sum_eta1 = inv_eta2_phi_tilde @ eta2_phi2.expand(
            N, -1, -1, -1)

        # eta2_1 @ [inv(eta2_1 + eta2_2) @ eta2_2]; shape (N, K, L, L)
        w_eta2 = torch.einsum('nju,nkui->nkij', eta2_phi1,
                              inv_eta2_eta2_sum_eta1)

        # symmetrise for numerical stability
        w_eta2 = (w_eta2 + w_eta2.transpose(dim0=-1, dim1=-2)) / 2.

        # inv(eta2_1 + eta2_2) @ eta1_2; shape (N, K, L, 1)
        mu_eta2_1_eta2_2 = inv_eta2_phi_tilde @ eta1_phi2.unsqueeze(
            0).unsqueeze(-1).expand(N, -1, -1, 1)

        # eta2_1^T @ [inv(eta2_1 + eta2_2) @ eta1_2]; shape (N, K, L)
        w_eta1 = torch.einsum('nuj,nkuv->nkj', eta2_phi1, mu_eta2_1_eta2_2)

        # encoder means act as the "observed" data of the recognition model
        mu_phi1, _ = gaussian.natural_to_standard(eta1_phi1, eta2_phi1)

        # log q(z|y, phi) [Lin, Khan, VMP + SVAE pg. 11, eq. 23]
        return gaussian.log_probability_nat(mu_phi1, w_eta1, w_eta2,
                                            pi_phi2), (w_eta1, w_eta2)

    def compute_elbo(self, y, reconstructions, theta, phi_tilde, x_k_samps,
                     log_z_given_y_phi, decoder_type):
        """
        Compute the ELBO-based objective of the latent-GMM model.

        Side effects: sets ``self.neg_reconstruction_error`` and
        ``self.regualizer_term_final`` and increments ``self.totiter``.

        Args:
            y: original data
            reconstructions: decoder output ``(means, out_2)``; ``out_2`` is
                passed to the Gaussian log-likelihood as the variances
            theta: GMM parameters ``(alpha, A, b, beta, v_hat)`` in natural
                form -- alpha for the Dirichlet prior, the rest for the NIW
                (beta is also called kappa, v_hat is the degrees of freedom)
            phi_tilde: ``(eta1_phi_tilde, eta2_phi_tilde)`` as returned by
                ``e_step``; eta2 has shape (N, K, L, L)
            x_k_samps: latent samples per component, shape (N, K, S, L)
            log_z_given_y_phi: log responsibilities, shape (N, K)
            decoder_type: only 'standard' (Gaussian) is implemented

        Returns:
            elbo: ``-(neg_reconstruction_error - regularizer)``.
                NOTE(review): if ``expected_diagonal_gaussian_loglike``
                returns a log-likelihood, this is the *negated* ELBO, i.e. a
                loss to minimise -- confirm the naming.
            details: tuple (reconstruction term, responsibility-weighted sum of
                the log-numerator, responsibility-weighted sum of the
                log-denominator, regularizer term)

        Raises:
            NotImplementedError: if ``decoder_type != 'standard'`` (raised
                before ``totiter`` is incremented).
        """

        # expected mean/covariance and expected log-weights under theta
        beta_k, m_k, C_k, v_k = niw.natural_to_standard(*theta[1:])
        mu, sigma = niw.expected_values((beta_k, m_k, C_k, v_k))
        eta1_theta, eta2_theta = gaussian.standard_to_natural(mu, sigma)
        alpha_k = dirichlet.natural_to_standard(theta[0])
        expected_log_pi_theta = dirichlet.expected_log_pi(alpha_k)

        # Don't backprop through GMM
        eta1_theta = eta1_theta.detach()
        eta2_theta = eta2_theta.detach()
        expected_log_pi_theta = expected_log_pi_theta.detach()

        # responsibilities q(z=k|y), shape (N, K)
        r_nk = torch.exp(log_z_given_y_phi)

        # reconstruction term, weighted by the responsibilities
        means_recon, out_2_recon = reconstructions  # out_2 holds the Gaussian variances
        if decoder_type == 'standard':
            self.neg_reconstruction_error = self.expected_diagonal_gaussian_loglike(
                y, means_recon, out_2_recon, weights=r_nk)
        else:
            raise NotImplementedError

        # compute E[log q_phi(x,z=k|y)]
        eta1_phi_tilde, eta2_phi_tilde = phi_tilde
        N, K, L, _ = eta2_phi_tilde.shape
        # drop the trailing singleton axis: (N, K, L, 1) -> (N, K, L)
        eta1_phi_tilde = torch.reshape(eta1_phi_tilde, (N, K, L))

        N, K, S, L = x_k_samps.shape

        # Compute the log-numerator; see Variational Message Passing with
        # Structured Inference Networks, pg. 5, equations 7 - 10:
        #   Log-Numerator   = log[ PROD(p(y_n|x_n, theta_NN) * p(x|theta_PGM) * Z(phi)) ]
        #   Log-Denominator = log[ PROD(q(x_n|f_phi_nn(y_n)) * q(x|phi_PGM)) ]
        #
        # p(y_n|x_n, theta_NN) / q(x_n|f_phi_nn(y_n)) is the reconstruction term
        # handled above; the remaining parts are
        #   E_q[log p(x|theta_PGM)] - E_q[log q(x|phi_PGM)] - log Z(phi)
        #
        # For a GMM, Z(phi) = sum_{k=1..K} N(m_n | mean_tilde_k, V_n + sigma_tilde_k) * pi_k
        # where m_n / V_n are the encoder mean / variance,
        #   sigma_tilde_n = inverse(V_n) + inverse(sigma_tilde_k)
        #   mean_tilde_n  = sigma_tilde_n * (inverse(V_n)*m_n + inverse(sigma_tilde_k)*mean_tilde_k)
        # which is what log_z_given_y_phi encodes per component.

        # log q(x, z|y) = log q(x|z, y, phi) + log q(z|y, phi)
        log_N_x_given_phi = gaussian.log_probability_nat_per_samp(
            x_k_samps, eta1_phi_tilde, eta2_phi_tilde)  # presumably (N, K, S): one value per sample -- TODO confirm
        log_numerator = log_N_x_given_phi + log_z_given_y_phi.unsqueeze(
            2)  # broadcast q(z|y,phi), which is only (N, K), over the sample axis

        # log p(x|theta) + E[log pi] under the (detached) GMM parameters
        log_N_x_given_theta = gaussian.log_probability_nat_per_samp(
            x_k_samps,
            eta1_theta.unsqueeze(0).expand(N, -1, -1),
            eta2_theta.expand(N, -1, -1, -1))  # presumably (N, K, S) as above -- TODO confirm
        log_denominator = log_N_x_given_theta + expected_log_pi_theta.unsqueeze(
            0).unsqueeze(2)

        # responsibility-weighted log ratio, summed over components (dim 1)
        # and data points (dim 0), then averaged over what remains
        regualizer_term_part_1 = r_nk.unsqueeze(2) * (log_numerator -
                                                      log_denominator)
        regualizer_term_part_2 = torch.sum(regualizer_term_part_1, dim=1)
        regualizer_term_part_3 = torch.sum(regualizer_term_part_2, dim=0)
        self.regualizer_term_final = torch.mean(regualizer_term_part_3)

        elbo = -1. * (self.neg_reconstruction_error -
                      self.regualizer_term_final)

        details = (self.neg_reconstruction_error,
                   torch.sum(r_nk * torch.mean(log_numerator, -1)),
                   torch.sum(r_nk * torch.mean(log_denominator, -1)),
                   self.regualizer_term_final)
        self.totiter += 1

        return elbo, details

    def compute_elbo_debug(self, y, reconstructions, theta, phi_tilde,
                           x_k_samps, log_z_given_y_phi, decoder_type):
        """
        Compute Reconstruction Error  -- For debugging purposes
        Args:
            y: original data
            reconstructions: reconststructed y
            theta: hyperparameters of GMM model 
                [alpha: prior Dirichlet parameters, beta/kappa: prior NiW, 
                controls variance of mean, m: prior of mean, c: prior of covariance, v: prior degrees of freedom ]
            phi_tilde: Natural Parameters of GMM
            x_k_samps: Latent Vectors produced from GMM
            log_z_given_y_phi: Mixture probabilities # Shape: N x K 
            decoder_type: Gaussian or Bernoulli decoder

        Returns:
            ELBO: evidence lower bound of reconstruction and KL divergence of prior and variational prior
            Details: Tuple of negative reconstruction error, numberator of regularizer, denominator of regulaizer, regualizer term
        """

        # Don't backprop through GMM
        r_nk = torch.exp(log_z_given_y_phi)

        # compute negative reconstruction error; sum over minibatch (use VAE function)
        means_recon, out_2_recon = reconstructions  # out_2 is gaussian variances
        if decoder_type == 'standard':
            self.neg_reconstruction_error = self.expected_diagonal_gaussian_loglike(
                y.to(self.device), means_recon, out_2_recon, weights=r_nk)
        else:
            raise NotImplementedError

        elbo = -1. * (self.neg_reconstruction_error)
        self.totiter += 1

        return elbo

    def expected_diagonal_gaussian_loglike(self,
                                           y,
                                           param1_recon,
                                           param2_recon,
                                           weights=None):
        """
        computes expected diagonal log-likelihood SUM_{n=1} E_{q(z)}[log N(x_n|mu(z), sigma(z))]

        The expectation is estimated by averaging over the sample dimension.

        Args:
            y: data; shape (size_minibatch, dims)
            param1_recon: predicted means; shape (size_minibatch, nb_samples, dims)
                or (size_minibatch, dims) when weights is None, and
                (size_minibatch, nb_comps, nb_samps, dims) otherwise
            param2_recon: predicted variances; shape is same as for means
            weights: None or matrix of shape (N, K) containing normalized
                weights. Each row is assumed to sum to 1; the normalization
                constant M * L / 2 * log(2 pi) relies on that.

        Returns:
            Scalar tensor: the (weighted) log-likelihood summed over the
            minibatch and data dimensions, averaged over samples. It lives on
            the same device as the inputs.
        """
        import math

        # A Python float keeps the constant device- and dtype-agnostic, so the
        # result follows the inputs instead of being forced onto CUDA.
        log_2pi = math.log(2. * math.pi)

        if weights is None:
            # required dimension: size_minibatch, nb_samples, dims
            if len(param1_recon.shape) != 3:
                param1_recon = param1_recon.unsqueeze(1)
            if len(param2_recon.shape) != 3:
                param2_recon = param2_recon.unsqueeze(1)
            M, S, L = param1_recon.shape
            assert list(y.shape) == [M, L]

            sample_mean = torch.sum(
                torch.pow(y.unsqueeze(1) - param1_recon, 2) /
                param2_recon) + torch.sum(torch.log(param2_recon))
        else:
            M, K, S, L = param1_recon.shape
            assert param2_recon.shape == param1_recon.shape
            assert list(weights.shape) == [M, K]

            # adjust y's shape (add component and sample dimensions)
            y = y.unsqueeze(1).unsqueeze(1)

            # Weighted sum over data points and components of the per-sample
            # squared Mahalanobis term plus log-variance (epsilon for stability).
            sample_mean = torch.einsum(
                'nksd,nk->',
                torch.pow(y - param1_recon, 2) / param2_recon +
                torch.log(param2_recon + 1e-8), weights)

        # Monte-Carlo average over the S samples.
        sample_mean = sample_mean / float(S)
        return -0.5 * sample_mean - float(M) * float(L) / 2. * log_2pi

    def update_gmm_params(self, current_gmm_params, gmm_params_star,
                          step_size):
        """
        Computes convex combination between current and updated parameters.

        Each returned entry is (1 - step_size) * current + step_size * star.

        Args:
            current_gmm_params: current parameters (sequence of tensors)
            gmm_params_star: parameters received by GMM-EM algorithm; same
                length as current_gmm_params
            step_size: step size for convex combination (scalar, numpy value
                or tensor); it is placed on each parameter's device

        Returns:
            List of blended parameters, in the same order as the inputs.

        Raises:
            ValueError: if the two parameter sequences differ in length.
        """
        current_gmm_params = list(current_gmm_params)
        gmm_params_star = list(gmm_params_star)
        if len(current_gmm_params) != len(gmm_params_star):
            raise ValueError(
                "expected {} updated GMM parameters, got {}".format(
                    len(current_gmm_params), len(gmm_params_star)))

        blended = []
        for curr_param, param_star in zip(current_gmm_params, gmm_params_star):
            # Follow the parameter's device instead of hard-coding CUDA.
            step = torch.as_tensor(step_size, device=curr_param.device)
            blended.append((1 - step) * curr_param + step * param_star)

        return blended

    def predict(self, y, phi_gmm, encoder_layers, decoder_layers, seed=0):
        """
        Cluster and reconstruct ``y`` using a single latent sample.

        Args:
            y: data to cluster and reconstruct
            phi_gmm: latent phi param (GMM parameters forwarded to e_step)
            encoder_layers: encoder NN architecture (passed to Encoder)
            decoder_layers: decoder NN architecture (passed to Decoder)
            seed: random seed, forwarded to both e_step and subsample_x

        Returns:
            Tuple (y_mean, clusters): the decoder's mean reconstruction of y,
            and the most probable cluster allocation per data point
            (argmax of log_r_nk over dim 1).
        """
        # NOTE(review): a fresh Encoder/Decoder is constructed from the layer
        # specs on every call; unless those constructors load trained weights,
        # the networks used here are untrained -- confirm against their code.
        nb_samples = 1
        phi_enc_model = Encoder(layerspecs=encoder_layers)
        phi_enc = phi_enc_model.forward(y)

        # Forward the caller's seed so both sampling steps honour it.
        x_k_samples, log_r_nk, _, _ = e_step(phi_enc,
                                             phi_gmm,
                                             nb_samples,
                                             seed=seed)
        # Keep the single drawn sample per data point.
        x_samples = subsample_x(x_k_samples, log_r_nk, seed)[:, 0, :]

        y_recon_model = Decoder(layerspecs=decoder_layers)
        y_mean, _ = y_recon_model.forward(x_samples)

        return (y_mean, torch.argmax(log_r_nk, dim=1))

    def m_step(self, gmm_prior, x_samples, r_nk):
        """
        Run the variational-EM M-step for a Bayesian GMM (Bishop) and return
        the resulting posterior in natural-parameter form.

        Args:
            gmm_prior: Dirichlet+NiW prior for Gaussian mixture model; entry 0
                is the Dirichlet part, the remaining entries the NiW part
            x_samples: samples of shape (N, S, L)
            r_nk: responsibilities of shape (N, K)

        Returns:
            Tuple (alpha, A, b, beta, v_hat): the Dirichlet natural parameter
            followed by the NiW natural parameters.
        """
        niw_natural = gmm_prior[1:]
        dirichlet_natural = gmm_prior[0]

        # Natural -> standard parameterisation of the priors.
        beta_0, m_0, C_0, v_0 = niw.natural_to_standard(*niw_natural)
        alpha_0 = dirichlet.natural_to_standard(dirichlet_natural)

        # Bishop's M-step; the component means/scatters it also returns are
        # not needed by callers of this method.
        (alpha_k, beta_k, m_k, C_k, v_k,
         _x_k, _S_k) = gmm.m_step(x_samples, r_nk, alpha_0, beta_0, m_0, C_0,
                                 v_0)

        # Standard -> natural parameterisation of the posterior.
        A, b, beta, v_hat = niw.standard_to_natural(beta_k, m_k, C_k, v_k)
        alpha = dirichlet.standard_to_natural(alpha_k)

        return alpha, A, b, beta, v_hat