예제 #1
0
def compute_z_moments(w_s, eta_old, H_old, psi_old):
    ''' Compute the first moment and the Cholesky factor of the variance of
    the latent variable of each layer, marginalised over the mixture paths.

    w_s (array-like of length S1): The path probabilities for all s in S1.
                        It is reshaped (C order) to one axis per layer, so its
                        length must equal the product of the K_l.
    eta_old (list of nb_layers elements of shape (K_l x r_{l-1}, 1)): eta
                        estimators of the previous iteration for each layer
    H_old (list of nb_layers elements of shape (K_l x r_l-1, r_l)): Lambda
                        estimators of the previous iteration for each layer
    psi_old (list of nb_layers elements of shape (K_l x r_l-1, r_l-1)): Psi
                        estimators of the previous iteration for each layer
    -------------------------------------------------------------------------
    returns (tuple of length 2): the list of E(z^{(l)}) and the list of the
                        Cholesky factors of Var(z^{(l)}), one entry per layer
    raises RuntimeError: if Var(z^{(l)}) cannot be made positive semi-definite
    '''

    # Number of mixture components of each layer (first axis of eta_l)
    k = [eta.shape[0] for eta in eta_old]
    L = len(eta_old)

    Ez = [[] for _ in range(L)]
    AT = [[] for _ in range(L)]

    # Joint path probabilities with one axis per layer
    w_reshaped = np.asarray(w_s).reshape(*k, order='C')

    for l in reversed(range(L)):
        # Marginal weight of each component of layer l: sum out every other layer
        idx_to_sum = tuple(ax for ax in range(L) if ax != l)
        wl = w_reshaped.sum(idx_to_sum)[..., n_axis, n_axis]

        # E(z^{(l)}) = sum_k w_k eta_k
        Ezl = (wl * eta_old[l]).sum(0, keepdims=True)
        Ez[l] = Ezl

        # E(z z^T) = sum_k w_k (H_k H_k^T + Psi_k + eta_k eta_k^T)
        eta_etaT = eta_old[l] @ t(eta_old[l], (0, 2, 1))
        HlHlT = H_old[l] @ t(H_old[l], (0, 2, 1))

        E_zlzlT = (wl * (HlHlT + psi_old[l] + eta_etaT)).sum(0, keepdims=True)
        var_zl = E_zlzlT - Ezl @ t(Ezl, (0, 2, 1))

        try:
            var_zl = ensure_psd([var_zl])[0]  # Numeric stability check
        except Exception as err:
            # Narrower than a bare except: interrupts are not masked, and the
            # underlying failure is kept as the exception cause
            print(var_zl)
            raise RuntimeError('Var z%d was not psd' % (l + 1)) from err

        AT[l] = cholesky(var_zl)

    return Ez, AT
예제 #2
0
def MDGMM(y, n_clusters, r, k, init, var_distrib, nj, it=50,
          eps=1E-05, maxstep=100, seed=None, perform_selec=True):
    ''' Fit a Mixed Deep Gaussian Mixture Model (MDGMM) on mixed-type data:
    a continuous head, a discrete head and a common tail, fitted by MCEM.

    y (numobs x p ndarray): The observations containing mixed variables
    n_clusters (int or str): The number of clusters to look for in the data,
                    or 'auto' (cluster on the first tail layer) or 'multi'
                    (one clustering per tail layer but the last)
    r (dict with keys 'c', 'd', 't'): The dimensions of the latent variables
                    of the continuous head, the discrete head and the tail
    k (dict with keys 'c', 'd', 't'): The number of components of the latent
                    Gaussian mixture layers of each part of the network
    init (dict): The initialisation parameters for the algorithm
    var_distrib (p 1darray): An array containing the types of the variables in y
    nj (p 1darray): For binary/count data: The maximum values that the variable can take.
                    For ordinal data: the number of different existing categories for each variable
                    For categorical data: the number of different existing categories for each variable
    it (int): The maximum number of MCEM iterations of the algorithm
    eps (float): If the likelihood increase by less than eps then the algorithm stops
    maxstep (int): The maximum number of optimisation step for each variable
    seed (int): The random state seed to set (Only for numpy generated data for the moment)
    perform_selec (Bool): Whether to perform architecture selection or not
    ------------------------------------------------------------------------------------------------
    returns (dict): 'likelihood' (the log-likelihood of each iteration),
                    'classes' (the predicted classes with the best silhouette
                    score, one array per tail layer in 'multi' mode),
                    'best_r' and 'best_k' (the selected architecture) and
                    'z' (a continuous representation of the data)
    raises ValueError: if y has no discrete or no continuous variables, or if
                    n_clusters is a string other than 'auto' or 'multi'
    '''

    # Break the reference link with the caller's dicts
    k = deepcopy(k)
    r = deepcopy(r)

    best_k = deepcopy(k)
    best_r = deepcopy(r)

    # Add other checks for the other variables
    check_inputs(k, r)

    prev_lik = -1E15
    best_lik = -1E15

    tol = 0.01
    max_patience = 1
    patience = 0

    #====================================================
    # Initialize the parameters
    #====================================================

    eta_c, eta_d, H_c, H_d, psi_c, psi_d = dispatch_dgmm_init(init)
    lambda_bin, lambda_ord, lambda_categ = dispatch_gllvm_init(init)
    w_s_c, w_s_d = dispatch_paths_init(init)

    numobs = len(y)
    likelihood = []
    it_num = 0
    ratio = 1000

    # Seed numpy's global generator. Assigning to np.random.seed would replace
    # the seeding function itself instead of seeding.
    if seed is not None:
        np.random.seed(seed)

    #====================================================
    # Dispatch variables between categories
    #====================================================

    bin_mask = np.logical_or(var_distrib == 'bernoulli',
                             var_distrib == 'binomial')
    y_bin = y[:, bin_mask]
    nj_bin = nj[bin_mask].astype(int)
    nb_bin = len(nj_bin)

    ord_mask = var_distrib == 'ordinal'
    y_ord = y[:, ord_mask]
    nj_ord = nj[ord_mask].astype(int)
    nb_ord = len(nj_ord)

    categ_mask = var_distrib == 'categorical'
    y_categ = y[:, categ_mask]
    nj_categ = nj[categ_mask].astype(int)
    nb_categ = len(nj_categ)

    yc = y[:, var_distrib == 'continuous']
    nb_cont = yc.shape[1]

    # Validate before scaling: the scaler rejects an array with 0 columns, which
    # would hide the explicit message below.
    if nb_bin + nb_ord + nb_categ == 0:  # Create the InputError class and change this
        raise ValueError('Input does not contain discrete variables, '
                         'consider using a regular DGMM')
    if nb_cont == 0:  # Create the InputError class and change this
        raise ValueError('Input does not contain continuous values, '
                         'consider using a DDGMM')

    ss = StandardScaler()
    yc = ss.fit_transform(yc)

    # *_1L stands for quantities going through all the network (head + tail)
    k_1L, L_1L, L, bar_L, S_1L = nb_comps_and_layers(k)
    r_1L = {'c': r['c'] + r['t'], 'd': r['d'] + r['t'], 't': r['t']}

    is_multi = n_clusters == 'multi'
    best_sil = [-1.1 for _ in range(L['t'] - 1)] if is_multi else -1.1
    new_sil = [-1.1 for _ in range(L['t'] - 1)] if is_multi else -1.1

    # Classes attaining the best silhouette so far. In 'multi' mode there is one
    # entry per tail layer, created once so that a layer keeps its best
    # partition when a later iteration does not improve it.
    classes = [[] for _ in range(L['t'] - 1)] if is_multi else None

    M = M_growth(1, r_1L, numobs)

    # Compute the Gower matrix (used for the silhouette scores)
    cat_features = np.logical_or(var_distrib == 'categorical', var_distrib == 'bernoulli')
    dm = gower_matrix(y, cat_features=cat_features)

    while (it_num < it) and ((ratio > eps) or (patience <= max_patience)):
        print(it_num)

        # The clustering layer is the one used to perform the clustering
        # i.e. the layer l such that k[l] == n_clusters
        if not(isnumeric(n_clusters)):
            if n_clusters == 'auto':
                clustering_layer = 0
            elif n_clusters == 'multi':
                clustering_layer = list(range(L['t'] - 1))
            else:
                raise ValueError('Please enter an int, auto or multi for n_clusters')
        else:
            assert (np.array(k['t']) == n_clusters).any()
            clustering_layer = np.argmax(np.array(k['t']) == n_clusters)

        #####################################################################################
        ################################# MC step ############################################
        #####################################################################################

        #=====================================================================
        # Draw from f(z^{l} | s, Theta) for both heads and tail
        #=====================================================================

        mu_s_c, sigma_s_c = compute_path_params(eta_c, H_c, psi_c)
        sigma_s_c = ensure_psd(sigma_s_c)

        mu_s_d, sigma_s_d = compute_path_params(eta_d, H_d, psi_d)
        sigma_s_d = ensure_psd(sigma_s_d)

        z_s_c, zc_s_c, z_s_d, zc_s_d = draw_z_s_all_network(mu_s_c, sigma_s_c,
                            mu_s_d, sigma_s_d, yc, eta_c, eta_d, S_1L, L, M)

        #========================================================================
        # Draw from f(z^{l+1} | z^{l}, s, Theta) for l >= 1
        #========================================================================

        # Create wrapper as before and after
        chsi_c = compute_chsi(H_c, psi_c, mu_s_c, sigma_s_c)
        chsi_c = ensure_psd(chsi_c)
        rho_c = compute_rho(eta_c, H_c, psi_c, mu_s_c, sigma_s_c, zc_s_c, chsi_c)

        chsi_d = compute_chsi(H_d, psi_d, mu_s_d, sigma_s_d)
        chsi_d = ensure_psd(chsi_d)
        rho_d = compute_rho(eta_d, H_d, psi_d, mu_s_d, sigma_s_d, zc_s_d, chsi_d)

        # In the following z2 and z1 will denote z^{l+1} and z^{l} respectively
        z2_z1s_c, z2_z1s_d = draw_z2_z1s_network(chsi_c, chsi_d, rho_c,
                                                 rho_d, M, r_1L, L)

        #=======================================================================
        # Compute the p(y^D| z1) for all discrete variables
        #=======================================================================

        py_zl1_d = fy_zl1(lambda_bin, y_bin, nj_bin, lambda_ord, y_ord, nj_ord,
                          lambda_categ, y_categ, nj_categ, z_s_d[0])

        #========================================================================
        # Draw from p(z1 | y, s) proportional to p(y | z1) * p(z1 | s) for all s
        #========================================================================

        zl1_ys_d = draw_zl1_ys(z_s_d, py_zl1_d, M['d'])

        #####################################################################################
        ################################# E step ############################################
        #####################################################################################

        #=====================================================================
        # Compute quantities necessary for E steps of both heads and tail
        #=====================================================================

        # Discrete head quantities
        pzl1_ys_d, ps_y_d, py_d = E_step_GLLVM(z_s_d[0], mu_s_d[0], sigma_s_d[0], w_s_d, py_zl1_d)
        py_s_d = ps_y_d * py_d / w_s_d[n_axis]

        # Continuous head quantities
        ps_y_c, py_s_c, py_c = continuous_lik(yc, mu_s_c[0], sigma_s_c[0], w_s_c)

        pz_s_d = fz_s(z_s_d, mu_s_d, sigma_s_d)
        # NOTE(review): pz_s_c is not used below in this function
        pz_s_c = fz_s(z_s_c, mu_s_c, sigma_s_c)

        #=====================================================================
        # Compute p(z^{(l)}| s, y). Equation (5) of the paper
        #=====================================================================

        # Compute pz2_z1s_d and pz2_z1s_d for the tail indices whereas it is useless

        pz2_z1s_d = fz2_z1s(t(pzl1_ys_d, (1, 0, 2)), z2_z1s_d, chsi_d, rho_d, S_1L['d'])
        pz_ys_d = fz_ys(t(pzl1_ys_d, (1, 0, 2)), pz2_z1s_d)

        pz2_z1s_c = fz2_z1s([], z2_z1s_c, chsi_c, rho_c, S_1L['c'])
        pz_ys_c = fz_ys([], pz2_z1s_c)

        pz2_z1s_t = fz2_z1s([], z2_z1s_c[bar_L['c']:], chsi_c[bar_L['c']:],
                            rho_c[bar_L['c']:], S_1L['t'])

        # Junction layer computations
        # Compute p(zC |s)
        py_zs_d = fy_zs(pz_ys_d, py_s_d)
        py_zs_c = fy_zs(pz_ys_c, py_s_c)

        # Compute p(zt | yC, yD, sC, SD)
        pzt_yCyDs = fz_yCyDs(py_zs_c, pz_ys_d, py_s_c, M, S_1L, L)

        #=====================================================================
        # Compute MFA expectations
        #=====================================================================

        # Discrete head.
        Ez_ys_d, E_z1z2T_ys_d, E_z2z2T_ys_d, EeeT_ys_d = \
            E_step_DGMM_d(zl1_ys_d, H_d, z_s_d, zc_s_d, z2_z1s_d, pz_ys_d,
                          pz2_z1s_d, S_1L['d'], L['d'])

        # Continuous head
        Ez_ys_c, E_z1z2T_ys_c, E_z2z2T_ys_c, EeeT_ys_c = \
            E_step_DGMM_c(H_c, z_s_c, zc_s_c, z2_z1s_c, pz_ys_c,
                          pz2_z1s_c, S_1L['c'], L['c'])

        # Junction layers
        Ez_ys_t, E_z1z2T_ys_t, E_z2z2T_ys_t, EeeT_ys_t = \
            E_step_DGMM_t(H_c[bar_L['c']:],
                          z_s_c[bar_L['c']:], zc_s_c[bar_L['c']:], z2_z1s_c[bar_L['c']:],
                          pzt_yCyDs, pz2_z1s_t, S_1L, L, k_1L)

        # Error here for the first two terms: p(y^h | z^t, s^C) != p(y^h | z^t, s^{1C:L})
        pst_yCyD = fst_yCyD(py_zs_c, py_zs_d, pz_s_d, w_s_c, w_s_d, k_1L, L)

        ###########################################################################
        ############################ M step #######################################
        ###########################################################################

        #=======================================================
        # Compute DGMM Parameters
        #=======================================================

        # Discrete head
        w_s_d = np.mean(ps_y_d, axis=0)
        eta_d_barL, H_d_barL, psi_d_barL = M_step_DGMM(Ez_ys_d, E_z1z2T_ys_d, E_z2z2T_ys_d,
                                        EeeT_ys_d, ps_y_d, H_d, k_1L['d'][:-1],
                                        L_1L['d'], r_1L['d'])

        # Add dispatching function here
        eta_d[:bar_L['d']] = eta_d_barL
        H_d[:bar_L['d']] = H_d_barL
        psi_d[:bar_L['d']] = psi_d_barL

        # Continuous head
        # NOTE(review): L_1L['c'] + 1 here vs L_1L['d'] for the discrete head; confirm intended
        w_s_c = np.mean(ps_y_c, axis=0)
        eta_c_barL, H_c_barL, psi_c_barL = M_step_DGMM(Ez_ys_c, E_z1z2T_ys_c, E_z2z2T_ys_c,
                                        EeeT_ys_c, ps_y_c, H_c, k_1L['c'][:-1],
                                        L_1L['c'] + 1, r_1L['c'])

        eta_c[:bar_L['c']] = eta_c_barL
        H_c[:bar_L['c']] = H_c_barL
        psi_c[:bar_L['c']] = psi_c_barL

        # Common tail: the same estimates are written into both heads' tail slots
        eta_t, H_t, psi_t, Ezst_y = M_step_DGMM_t(Ez_ys_t, E_z1z2T_ys_t, E_z2z2T_ys_t,
                                        EeeT_ys_t, ps_y_c, ps_y_d, pst_yCyD,
                                        H_c[bar_L['c']:], S_1L, k_1L,
                                        L_1L, L, r_1L['t'])

        eta_d[bar_L['d']:] = eta_t
        H_d[bar_L['d']:] = H_t
        psi_d[bar_L['d']:] = psi_t

        eta_c[bar_L['c']:] = eta_t
        H_c[bar_L['c']:] = H_t
        psi_c[bar_L['c']:] = psi_t

        #=======================================================
        # Identifiability conditions
        #=======================================================
        w_s_t = np.mean(pst_yCyD, axis=0)
        eta_d, H_d, psi_d, AT_d, eta_c, H_c, psi_c, AT_c = network_identifiability(eta_d,
                                H_d, psi_d, eta_c, H_c, psi_c, w_s_c, w_s_d, w_s_t, bar_L)

        #=======================================================
        # Compute GLLVM Parameters
        #=======================================================

        # We optimize each column separately as it is faster than all column jointly
        # (and more relevant with the independence hypothesis)

        lambda_bin = bin_params_GLLVM(y_bin, nj_bin, lambda_bin, ps_y_d,
                    pzl1_ys_d, z_s_d[0], AT_d[0], tol=tol, maxstep=maxstep)

        lambda_ord = ord_params_GLLVM(y_ord, nj_ord, lambda_ord, ps_y_d,
                    pzl1_ys_d, z_s_d[0], AT_d[0], tol=tol, maxstep=maxstep)

        lambda_categ = categ_params_GLLVM(y_categ, nj_categ, lambda_categ, ps_y_d,
                    pzl1_ys_d, z_s_d[0], AT_d[0], tol=tol, maxstep=maxstep)

        ###########################################################################
        ################## Clustering parameters updating #########################
        ###########################################################################

        new_lik = np.sum(np.log(py_d) + np.log(py_c))
        likelihood.append(new_lik)
        ratio = (new_lik - prev_lik) / abs(prev_lik)

        if n_clusters == 'multi':
            temp_classes = []
            z_tail = []

            for l in clustering_layer:
                # clustering_layer == list(range(L['t'] - 1)), so layer l is axis l + 1
                idx_to_sum = tuple(set(range(1, L['t'] + 1)) - {l + 1})
                psl_y = pst_yCyD.reshape(numobs, *k['t'],
                                         order='C').sum(idx_to_sum)

                temp_class_l = np.argmax(psl_y, axis=1)
                try:
                    sil_l = silhouette_score(dm, temp_class_l, metric='precomputed')
                except ValueError:  # e.g. a single class was predicted
                    sil_l = -1

                temp_classes.append(temp_class_l)
                new_sil[l] = sil_l

            for l in range(L['t'] - 1):
                zl = Ezst_y[l].sum(1)
                z_tail.append(zl)

                if best_sil[l] < new_sil[l]:
                    # Update the quantity if the silhouette score is better
                    best_sil[l] = new_sil[l]
                    classes[l] = deepcopy(temp_classes[l])

                    if zl.shape[-1] == 3:
                        plot_3d(zl, classes[l])
                    elif zl.shape[-1] == 2:
                        plot_2d(zl, classes[l])

        else:
            idx_to_sum = tuple(set(range(1, L['t'] + 1)) - {clustering_layer + 1})
            psl_y = pst_yCyD.reshape(numobs, *k['t'], order='C').sum(idx_to_sum)

            temp_classes = np.argmax(psl_y, axis=1)
            try:
                new_sil = silhouette_score(dm, temp_classes, metric='precomputed')
            except ValueError:  # e.g. a single class was predicted
                new_sil = -1

            # NOTE(review): z_tail is refreshed every iteration, so the returned
            # representation is the last one, not the best-silhouette one
            z_tail = [Ezst_y[l].sum(1) for l in range(L['t'] - 1)]

            if best_sil < new_sil:
                # Update the quantity if the silhouette score is better
                zl = z_tail[clustering_layer]
                best_sil = new_sil
                classes = deepcopy(temp_classes)

                if zl.shape[-1] == 3:
                    plot_3d(zl, classes)
                elif zl.shape[-1] == 2:
                    plot_2d(zl, classes)

        # Refresh the likelihood if best
        if best_lik < new_lik:
            best_lik = new_lik

        if prev_lik < new_lik:
            patience = 0
            M = M_growth(it_num + 1, r_1L, numobs)
        else:
            patience += 1

        ###########################################################################
        ######################## Parameter selection  #############################
        ###########################################################################

        min_nb_clusters = 2
        is_not_min_specif = not(is_min_architecture_reached(k, r, min_nb_clusters))

        if look_for_simpler_network(it_num) & perform_selec & is_not_min_specif:

            # To add: selection according to categ
            r_to_keep = r_select(y_bin, y_ord, y_categ, yc, zl1_ys_d,
                                 z2_z1s_d[:bar_L['d']], w_s_d, z2_z1s_c[:bar_L['c']],
                                 z2_z1s_c[bar_L['c']:], n_clusters)

            # Check layer deletion
            is_c_layer_deletion = np.any([len(rl) == 0 for rl in r_to_keep['c']])
            is_d_layer_deletion = np.any([len(rl) == 0 for rl in r_to_keep['d']])
            is_head_layer_deletion = np.any([is_c_layer_deletion, is_d_layer_deletion])

            if is_head_layer_deletion:
                # Restart the algorithm from a fresh initialisation of the smaller network
                if is_c_layer_deletion:
                    r['c'] = [len(rl) for rl in r_to_keep['c'][:-1]]
                    k['c'] = k['c'][:-1]
                if is_d_layer_deletion:
                    r['d'] = [len(rl) for rl in r_to_keep['d'][:-1]]
                    k['d'] = k['d'][:-1]

                init = dim_reduce_init(pd.DataFrame(y), n_clusters, k, r, nj, var_distrib,
                                       seed=None)

                eta_c, eta_d, H_c, H_d, psi_c, psi_d = dispatch_dgmm_init(init)
                lambda_bin, lambda_ord, lambda_categ = dispatch_gllvm_init(init)
                w_s_c, w_s_d = dispatch_paths_init(init)

                # *_1L stands for quantities going through all the network (head + tail)
                k_1L, L_1L, L, bar_L, S_1L = nb_comps_and_layers(k)
                r_1L = {'c': r['c'] + r['t'], 'd': r['d'] + r['t'], 't': r['t']}

                M = M_growth(it_num + 1, r_1L, numobs)

                prev_lik = new_lik
                it_num = it_num + 1
                print(likelihood)

                print('Restarting the algorithm')
                continue

            # If r_l == 0, delete the last l + 1: layers
            new_Lt = np.sum([len(rl) != 0 for rl in r_to_keep['t']])

            k_to_keep = k_select(w_s_c, w_s_d, w_s_t, k, new_Lt, clustering_layer, n_clusters)

            is_selection = check_if_selection(r_to_keep, r, k_to_keep, k, L, new_Lt)

            assert new_Lt > 0  # > 1 ?
            if n_clusters == 'multi':
                assert new_Lt == L['t']

            if is_selection:

                # Part to change when update also number of layers on each head
                nb_deleted_layers_tail = L['t'] - new_Lt
                L['t'] = new_Lt
                L_1L = {keys: values - nb_deleted_layers_tail for keys, values in L_1L.items()}

                eta_c, eta_d, H_c, H_d, psi_c, psi_d = dgmm_coeff_selection(eta_c,
                            H_c, psi_c, eta_d, H_d, psi_d, L, r_to_keep, k_to_keep)

                lambda_bin, lambda_ord, lambda_categ = gllvm_coeff_selection(lambda_bin, lambda_ord,
                                                               lambda_categ, r, r_to_keep)

                w_s_c, w_s_d = path_proba_selection(w_s_c, w_s_d, k, k_to_keep, new_Lt)

                k = {h: [len(k_to_keep[h][l]) for l in range(L[h])] for h in ['d', 't']}
                k['c'] = [len(k_to_keep['c'][l]) for l in range(L['c'] + 1)]

                r = {h: [len(r_to_keep[h][l]) for l in range(L[h])] for h in ['d', 't']}
                r['c'] = [len(r_to_keep['c'][l]) for l in range(L['c'] + 1)]

                k_1L, _, L, bar_L, S_1L = nb_comps_and_layers(k)
                r_1L = {'c': r['c'] + r['t'], 'd': r['d'] + r['t'], 't': r['t']}

                patience = 0
                best_r = deepcopy(r)
                best_k = deepcopy(k)

                #=======================================================
                # Identifiability conditions
                #=======================================================
                eta_d, H_d, psi_d, AT_d, eta_c, H_c, psi_c, AT_c = network_identifiability(eta_d,
                                H_d, psi_d, eta_c, H_c, psi_c, w_s_c, w_s_d, w_s_t, bar_L)

            print('New architecture:')
            print('k', k)
            print('r', r)
            print('L', L)
            print('S_1L', S_1L)
            print("w_s_c", len(w_s_c))
            print("w_s_d", len(w_s_d))

        M = M_growth(it_num + 1, r_1L, numobs)

        prev_lik = new_lik
        print(likelihood)
        print('Silhouette score:', new_sil)

        it_num = it_num + 1

    out = dict(likelihood=likelihood, classes=classes,
               best_r=best_r, best_k=best_k)
    if n_clusters == 'multi':
        out['z'] = z_tail
    else:
        out['z'] = z_tail[clustering_layer]
    return out
예제 #3
0
def DDGMM(y, n_clusters, r, k, init, var_distrib, nj, it = 50, \
          eps = 1E-05, maxstep = 100, seed = None, perform_selec = True):
    ''' Fit a Generalized Linear Mixture of Latent Variables Model (GLMLVM)
    
    y (numobs x p ndarray): The observations containing categorical variables
    n_clusters (int): The number of clusters to look for in the data
    r (list): The dimension of latent variables through the first 2 layers
    k (list): The number of components of the latent Gaussian mixture layers
    init (dict): The initialisation parameters for the algorithm
    var_distrib (p 1darray): An array containing the types of the variables in y 
    nj (p 1darray): For binary/count data: The maximum values that the variable can take. 
                    For ordinal data: the number of different existing categories for each variable
    it (int): The maximum number of MCEM iterations of the algorithm
    eps (float): If the likelihood increase by less than eps then the algorithm stops
    maxstep (int): The maximum number of optimisation step for each variable
    seed (int): The random state seed to set (Only for numpy generated data for the moment)
    perform_selec (Bool): Whether to perform architecture selection or not
    ------------------------------------------------------------------------------------------------
    returns (dict): The predicted classes, the likelihood through the EM steps
                    and a continuous representation of the data
    '''

    prev_lik = -1E16
    best_lik = -1E16
    tol = 0.01
    max_patience = 1
    patience = 0

    best_k = deepcopy(k)
    best_r = deepcopy(r)

    best_sil = -1
    new_sil = -1

    # Initialize the parameters
    eta = deepcopy(init['eta'])
    psi = deepcopy(init['psi'])
    lambda_bin = deepcopy(init['lambda_bin'])
    lambda_ord = deepcopy(init['lambda_ord'])
    lambda_categ = deepcopy(init['lambda_categ'])

    H = deepcopy(init['H'])
    w_s = deepcopy(
        init['w_s']
    )  # Probability of path s' through the network for all s' in Omega

    numobs = len(y)
    likelihood = []
    it_num = 0
    ratio = 1000
    np.random.seed = seed

    # Dispatch variables between categories
    y_bin = y[:,
              np.logical_or(var_distrib == 'bernoulli', var_distrib ==
                            'binomial')]
    nj_bin = nj[np.logical_or(var_distrib == 'bernoulli',
                              var_distrib == 'binomial')].astype(int)
    nb_bin = len(nj_bin)

    y_categ = y[:, var_distrib == 'categorical']
    nj_categ = nj[var_distrib == 'categorical'].astype(int)
    nb_categ = len(nj_categ)

    y_ord = y[:, var_distrib == 'ordinal']
    nj_ord = nj[var_distrib == 'ordinal'].astype(int)
    nb_ord = len(nj_ord)

    L = len(k)
    k_aug = k + [1]
    S = np.array([np.prod(k_aug[l:]) for l in range(L + 1)])
    M = M_growth(1, r, numobs)

    assert nb_ord + nb_bin + nb_categ > 0

    # Compute the Gower matrix
    cat_features = np.logical_or(var_distrib == 'categorical',
                                 var_distrib == 'bernoulli')
    dm = gower_matrix(y, cat_features=cat_features)

    while (it_num < it) & ((ratio > eps) | (patience <= max_patience)):
        print(it_num)

        # The clustering layer is the one used to perform the clustering
        # i.e. the layer l such that k[l] == n_clusters
        clustering_layer = np.argmax(np.array(k) == n_clusters)

        #####################################################################################
        ################################# S step ############################################
        #####################################################################################

        #=====================================================================
        # Draw from f(z^{l} | s, Theta) for all s in Omega
        #=====================================================================

        mu_s, sigma_s = compute_path_params(eta, H, psi)
        sigma_s = ensure_psd(sigma_s)
        z_s, zc_s = draw_z_s(mu_s, sigma_s, eta, M)
        '''
        print('mu_s',  np.abs(mu_s[0]).mean())
        print('sigma_s',  np.abs(sigma_s[0]).mean())
        print('z_s0',  np.abs(z_s[0]).mean())
        print('z_s1',  np.abs(z_s[1]).mean(0)[:,0])
        '''

        #========================================================================
        # Draw from f(z^{l+1} | z^{l}, s, Theta) for l >= 1
        #========================================================================

        chsi = compute_chsi(H, psi, mu_s, sigma_s)
        chsi = ensure_psd(chsi)
        rho = compute_rho(eta, H, psi, mu_s, sigma_s, zc_s, chsi)

        # In the following z2 and z1 will denote z^{l+1} and z^{l} respectively
        z2_z1s = draw_z2_z1s(chsi, rho, M, r)

        #=======================================================================
        # Compute the p(y| z1) for all variable categories
        #=======================================================================

        py_zl1 = fy_zl1(lambda_bin, y_bin, nj_bin, lambda_ord, y_ord, nj_ord,
                        lambda_categ, y_categ, nj_categ, z_s[0])

        #========================================================================
        # Draw from p(z1 | y, s) proportional to p(y | z1) * p(z1 | s) for all s
        #========================================================================

        zl1_ys = draw_zl1_ys(z_s, py_zl1, M)

        #####################################################################################
        ################################# E step ############################################
        #####################################################################################

        #=====================================================================
        # Compute conditional probabilities used in the appendix of asta paper
        #=====================================================================

        pzl1_ys, ps_y, p_y = E_step_GLLVM(z_s[0], mu_s[0], sigma_s[0], w_s,
                                          py_zl1)
        #del(py_zl1)

        #=====================================================================
        # Compute p(z^{(l)}| s, y). Equation (5) of the paper
        #=====================================================================

        pz2_z1s = fz2_z1s(t(pzl1_ys, (1, 0, 2)), z2_z1s, chsi, rho, S)
        pz_ys = fz_ys(t(pzl1_ys, (1, 0, 2)), pz2_z1s)

        #=====================================================================
        # Compute MFA expectations
        #=====================================================================

        Ez_ys, E_z1z2T_ys, E_z2z2T_ys, EeeT_ys = \
            E_step_DGMM(zl1_ys, H, z_s, zc_s, z2_z1s, pz_ys, pz2_z1s, S)

        ###########################################################################
        ############################ M step #######################################
        ###########################################################################

        #=======================================================
        # Compute MFA Parameters
        #=======================================================

        w_s = np.mean(ps_y, axis=0)
        eta, H, psi = M_step_DGMM(Ez_ys, E_z1z2T_ys, E_z2z2T_ys, EeeT_ys, ps_y,
                                  H, k)

        #=======================================================
        # Identifiability conditions
        #=======================================================

        # Update eta, H and Psi values
        H = diagonal_cond(H, psi)
        Ez, AT = compute_z_moments(w_s, eta, H, psi)
        eta, H, psi = identifiable_estim_DGMM(eta, H, psi, Ez, AT)

        del (Ez)

        #=======================================================
        # Compute GLLVM Parameters
        #=======================================================

        # We optimize each column separately as it is faster than all column jointly
        # (and more relevant with the independence hypothesis)

        lambda_bin = bin_params_GLLVM(y_bin, nj_bin, lambda_bin, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        lambda_ord = ord_params_GLLVM(y_ord, nj_ord, lambda_ord, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        lambda_categ = categ_params_GLLVM(y_categ, nj_categ, lambda_categ, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        ###########################################################################
        ################## Clustering parameters updating #########################
        ###########################################################################

        new_lik = np.sum(np.log(p_y))
        likelihood.append(new_lik)
        ratio = (new_lik - prev_lik) / abs(prev_lik)
        print(likelihood)

        idx_to_sum = tuple(set(range(1, L + 1)) - set([clustering_layer + 1]))
        psl_y = ps_y.reshape(numobs, *k, order='C').sum(idx_to_sum)

        temp_class = np.argmax(psl_y, axis=1)
        try:
            new_sil = silhouette_score(dm, temp_class, metric='precomputed')
        except ValueError:
            new_sil = -1

        print('Silhouette score:', new_sil)
        if best_sil < new_sil:
            z = (ps_y[..., n_axis] * Ez_ys[clustering_layer]).sum(1)
            best_sil = deepcopy(new_sil)
            classes = deepcopy(temp_class)

            fig = plt.figure(figsize=(8, 8))
            plt.scatter(z[:, 0], z[:, 1])
            plt.show()

        # Refresh the classes only if they provide a better explanation of the data
        if best_lik < new_lik:
            best_lik = deepcopy(prev_lik)

        if prev_lik < new_lik:
            patience = 0
            M = M_growth(it_num + 2, r, numobs)
        else:
            patience += 1

        ###########################################################################
        ######################## Parameter selection  #############################
        ###########################################################################

        is_not_min_specif = not (np.all(np.array(k) == n_clusters)
                                 & np.array_equal(r, [2, 1]))

        if look_for_simpler_network(
                it_num) & perform_selec & is_not_min_specif:
            r_to_keep = r_select(y_bin, y_ord, y_categ, zl1_ys, z2_z1s, w_s)

            # If r_l == 0, delete the last l + 1: layers
            new_L = np.sum([len(rl) != 0 for rl in r_to_keep]) - 1

            k_to_keep = k_select(w_s, k, new_L, clustering_layer)

            is_L_unchanged = L == new_L
            is_r_unchanged = np.all(
                [len(r_to_keep[l]) == r[l] for l in range(new_L + 1)])
            is_k_unchanged = np.all(
                [len(k_to_keep[l]) == k[l] for l in range(new_L)])

            is_selection = not (is_r_unchanged & is_k_unchanged
                                & is_L_unchanged)

            assert new_L > 0

            if is_selection:

                eta = [eta[l][k_to_keep[l]] for l in range(new_L)]
                eta = [eta[l][:, r_to_keep[l]] for l in range(new_L)]

                H = [H[l][k_to_keep[l]] for l in range(new_L)]
                H = [H[l][:, r_to_keep[l]] for l in range(new_L)]
                H = [H[l][:, :, r_to_keep[l + 1]] for l in range(new_L)]

                psi = [psi[l][k_to_keep[l]] for l in range(new_L)]
                psi = [psi[l][:, r_to_keep[l]] for l in range(new_L)]
                psi = [psi[l][:, :, r_to_keep[l]] for l in range(new_L)]

                if nb_bin > 0:
                    # Add the intercept:
                    bin_r_to_keep = np.concatenate([[0],
                                                    np.array(r_to_keep[0]) + 1
                                                    ])
                    lambda_bin = lambda_bin[:, bin_r_to_keep]

                if nb_ord > 0:
                    # Intercept coefficients handling is a little more complicated here
                    lambda_ord_intercept = [
                        lambda_ord_j[:-r[0]] for lambda_ord_j in lambda_ord
                    ]
                    Lambda_ord_var = np.stack(
                        [lambda_ord_j[-r[0]:] for lambda_ord_j in lambda_ord])
                    Lambda_ord_var = Lambda_ord_var[:, r_to_keep[0]]
                    lambda_ord = [np.concatenate([lambda_ord_intercept[j], Lambda_ord_var[j]])\
                                  for j in range(nb_ord)]

                if nb_categ > 0:
                    lambda_categ_intercept = [
                        lambda_categ[j][:, 0] for j in range(nb_categ)
                    ]
                    Lambda_categ_var = [
                        lambda_categ_j[:, -r[0]:]
                        for lambda_categ_j in lambda_categ
                    ]
                    Lambda_categ_var = [
                        lambda_categ_j[:, r_to_keep[0]]
                        for lambda_categ_j in lambda_categ
                    ]

                    lambda_categ = [np.hstack([lambda_categ_intercept[j][..., n_axis], Lambda_categ_var[j]])\
                                   for j in range(nb_categ)]

                w = w_s.reshape(*k, order='C')
                new_k_idx_grid = np.ix_(*k_to_keep[:new_L])

                # If layer deletion, sum the last components of the paths
                if L > new_L:
                    deleted_dims = tuple(range(L)[new_L:])
                    w_s = w[new_k_idx_grid].sum(deleted_dims).flatten(
                        order='C')
                else:
                    w_s = w[new_k_idx_grid].flatten(order='C')

                w_s /= w_s.sum()

                k = [len(k_to_keep[l]) for l in range(new_L)]
                r = [len(r_to_keep[l]) for l in range(new_L + 1)]

                k_aug = k + [1]
                S = np.array([np.prod(k_aug[l:]) for l in range(new_L + 1)])
                L = new_L

                patience = 0
                best_r = deepcopy(r)
                best_k = deepcopy(k)

                # Identifiability conditions
                H = diagonal_cond(H, psi)
                Ez, AT = compute_z_moments(w_s, eta, H, psi)
                eta, H, psi = identifiable_estim_DGMM(eta, H, psi, Ez, AT)

            print('New architecture:')
            print('k', k)
            print('r', r)
            print('L', L)
            print('S', S)
            print("w_s", len(w_s))

        prev_lik = deepcopy(new_lik)
        it_num = it_num + 1

    out = dict(likelihood = likelihood, classes = classes, z = z, \
               best_r = best_r, best_k = best_k)
    return (out)
예제 #4
0
def M1DGMM(y, n_clusters, r, k, init, var_distrib, nj, it=50, eps=1E-05,
           maxstep=100, seed=None, perform_selec=True, dm=None,
           max_patience=1, use_silhouette=True):
    ''' Fit a Generalized Linear Mixture of Latent Variables Model (GLMLVM)
    
    y (numobs x p ndarray): The observations containing mixed variables
    n_clusters (int or 'auto'): The number of clusters to look for in the data.
                    If 'auto', the first layer is used as clustering layer
    r (list): The dimension of latent variables through the first 2 layers
    k (list): The number of components of the latent Gaussian mixture layers
    init (dict): The initialisation parameters for the algorithm. Read keys:
                 'eta', 'psi', 'lambda_bin', 'lambda_ord', 'lambda_cont',
                 'lambda_categ', 'H', 'w_s' (deep-copied, never mutated)
    var_distrib (p 1darray): An array containing the types of the variables in y 
    nj (p 1darray): For binary/count data: The maximum values that the variable can take. 
                    For ordinal data: the number of different existing categories for each variable
    it (int): The maximum number of MCEM iterations of the algorithm
    eps (float): If the relative likelihood change is below eps then the algorithm stops
    maxstep (int): The maximum number of optimisation step for each variable
    seed (int or None): The random state seed of numpy's global generator.
                    If None, the generator is left untouched
    perform_selec (Bool): Whether to perform architecture selection or not
    dm (numobs x numobs ndarray or None): A precomputed dissimilarity matrix used
                    to compute the silhouette score. If None (or empty), the Gower
                    matrix of y is computed
    max_patience (int): The number of iterations without likelihood improvement
                    tolerated before stopping
    use_silhouette (Bool): If True use the silhouette as quality criterion (best for clustering) else use
                            the likelihood (best for data augmentation).
    ------------------------------------------------------------------------------------------------
    returns (dict): Always contains 'likelihood' and 'silhouette' (one value per
                    iteration). The keys 'classes', 'best_z', 'Ez.y', 'best_k',
                    'best_r', 'best_w_s', 'lambda_bin', 'lambda_ord', 'lambda_categ',
                    'lambda_cont', 'eta', 'mu', 'sigma', 'psl_y' and 'ps_y' hold the
                    state of the best iteration according to the chosen criterion
    raises: ValueError if some variable types are not understood or if
            n_clusters is neither numeric nor 'auto'
    '''

    prev_lik = -1E16
    best_lik = -1E16
    best_sil = -1

    tol = 0.01
    patience = 0
    is_looking_for_better_arch = False

    # Initialize the parameters (deep copies so that `init` is never mutated)
    eta = deepcopy(init['eta'])
    psi = deepcopy(init['psi'])
    lambda_bin = deepcopy(init['lambda_bin'])
    lambda_ord = deepcopy(init['lambda_ord'])
    lambda_cont = deepcopy(init['lambda_cont'])
    lambda_categ = deepcopy(init['lambda_categ'])

    H = deepcopy(init['H'])
    w_s = deepcopy(init['w_s']) # Probability of path s' through the network for all s' in Omega

    numobs = len(y)
    likelihood = []
    silhouette = []
    it_num = 0
    ratio = 1000

    # Seed numpy's global generator. Assigning to `np.random.seed` (instead of
    # calling it) would replace the function itself and never seed anything.
    if seed is not None:
        np.random.seed(seed)
    out = {} # Store the full output

    # Dispatch variables between categories
    is_bin = np.logical_or(var_distrib == 'bernoulli', var_distrib == 'binomial')
    y_bin = y[:, is_bin]
    nj_bin = nj[is_bin].astype(int)
    nb_bin = len(nj_bin)

    y_ord = y[:, var_distrib == 'ordinal']
    nj_ord = nj[var_distrib == 'ordinal'].astype(int)
    nb_ord = len(nj_ord)

    y_categ = y[:, var_distrib == 'categorical']
    nj_categ = nj[var_distrib == 'categorical'].astype(int)
    nb_categ = len(nj_categ)

    y_cont = y[:, var_distrib == 'continuous'].astype(float)
    nb_cont = y_cont.shape[1]

    # Set the continuous variables' standard error to 1.
    # Constant columns (zero std) are left unscaled to avoid producing NaNs.
    cont_std = y_cont.std(axis=0, keepdims=True)
    cont_std[cont_std == 0] = 1
    y_cont = y_cont / cont_std

    L = len(k)
    k_aug = k + [1]
    # S[l]: number of paths going through the layers l, ..., L - 1
    S = np.array([np.prod(k_aug[l:]) for l in range(L + 1)])
    M = M_growth(1, r, numobs)

    assert nb_bin + nb_ord + nb_cont + nb_categ > 0 
    if nb_bin + nb_ord + nb_cont + nb_categ != len(var_distrib):
        raise ValueError('Some variable types were not understood,\
                         existing types are: continuous, categorical,\
                         ordinal, binomial and bernoulli')

    # Compute the Gower matrix if no dissimilarity matrix was provided
    if dm is None or len(dm) == 0:
        cat_features = np.logical_or(var_distrib == 'categorical', var_distrib == 'bernoulli')
        dm = gower_matrix(y, cat_features = cat_features)

    # Do not stop the iterations if there are some iterations left or if the likelihood is increasing
    # or if we have not reached the maximum patience and if a new architecture was looked for
    # in the previous iteration
    while ((it_num < it) & (ratio > eps) & (patience <= max_patience)) | is_looking_for_better_arch:
        print(it_num)

        # The clustering layer is the one used to perform the clustering 
        # i.e. the layer l such that k[l] == n_clusters
        if not(isnumeric(n_clusters)):
            if n_clusters == 'auto':
                clustering_layer = 0
            else:
                raise ValueError('Please enter an int or "auto" for n_clusters')
        else:
            assert (np.array(k) == n_clusters).any()
            clustering_layer = np.argmax(np.array(k) == n_clusters)

        #####################################################################################
        ################################# S step ############################################
        #####################################################################################

        #=====================================================================
        # Draw from f(z^{l} | s, Theta) for all s in Omega
        #=====================================================================

        mu_s, sigma_s = compute_path_params(eta, H, psi)
        sigma_s = ensure_psd(sigma_s)
        z_s, zc_s = draw_z_s(mu_s, sigma_s, eta, M)

        #========================================================================
        # Draw from f(z^{l+1} | z^{l}, s, Theta) for l >= 1
        #========================================================================

        chsi = compute_chsi(H, psi, mu_s, sigma_s)
        chsi = ensure_psd(chsi)
        rho = compute_rho(eta, H, psi, mu_s, sigma_s, zc_s, chsi)

        # In the following z2 and z1 will denote z^{l+1} and z^{l} respectively
        z2_z1s = draw_z2_z1s(chsi, rho, M, r)

        #=======================================================================
        # Compute the p(y| z1) for all variable categories
        #=======================================================================

        py_zl1 = fy_zl1(lambda_bin, y_bin, nj_bin, lambda_ord, y_ord, nj_ord, \
                        lambda_categ, y_categ, nj_categ, y_cont, lambda_cont, z_s[0])

        #========================================================================
        # Draw from p(z1 | y, s) proportional to p(y | z1) * p(z1 | s) for all s
        #========================================================================

        zl1_ys = draw_zl1_ys(z_s, py_zl1, M)

        #####################################################################################
        ################################# E step ############################################
        #####################################################################################

        #=====================================================================
        # Compute conditional probabilities used in the appendix of asta paper
        #=====================================================================

        pzl1_ys, ps_y, p_y = E_step_GLLVM(z_s[0], mu_s[0], sigma_s[0], w_s, py_zl1)

        #=====================================================================
        # Compute p(z^{(l)}| s, y). Equation (5) of the paper
        #=====================================================================

        pz2_z1s = fz2_z1s(t(pzl1_ys, (1, 0, 2)), z2_z1s, chsi, rho, S)
        pz_ys = fz_ys(t(pzl1_ys, (1, 0, 2)), pz2_z1s)

        #=====================================================================
        # Compute MFA expectations
        #=====================================================================

        Ez_ys, E_z1z2T_ys, E_z2z2T_ys, EeeT_ys = \
            E_step_DGMM(zl1_ys, H, z_s, zc_s, z2_z1s, pz_ys, pz2_z1s, S)

        ###########################################################################
        ############################ M step #######################################
        ###########################################################################

        #=======================================================
        # Compute MFA Parameters 
        #=======================================================

        w_s = np.mean(ps_y, axis = 0)
        eta, H, psi = M_step_DGMM(Ez_ys, E_z1z2T_ys, E_z2z2T_ys, EeeT_ys, ps_y, H, k)

        #=======================================================
        # Identifiability conditions
        #=======================================================

        # Update eta, H and Psi values
        H = diagonal_cond(H, psi)
        Ez, AT = compute_z_moments(w_s, eta, H, psi)
        eta, H, psi = identifiable_estim_DGMM(eta, H, psi, Ez, AT)

        del(Ez)

        #=======================================================
        # Compute GLLVM Parameters
        #=======================================================

        lambda_bin = bin_params_GLLVM(y_bin, nj_bin, lambda_bin, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        lambda_ord = ord_params_GLLVM(y_ord, nj_ord, lambda_ord, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        lambda_categ = categ_params_GLLVM(y_categ, nj_categ, lambda_categ, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        lambda_cont = cont_params_GLLVM(y_cont, lambda_cont, ps_y, pzl1_ys, z_s[0], AT[0],\
                     tol = tol, maxstep = maxstep)

        ###########################################################################
        ################## Clustering parameters updating #########################
        ###########################################################################

        new_lik = np.sum(np.log(p_y))
        likelihood.append(new_lik)
        ratio = abs((new_lik - prev_lik)/prev_lik)

        # Marginalize the path probabilities over every layer but the clustering one
        idx_to_sum = tuple(set(range(1, L + 1)) - set([clustering_layer + 1]))
        psl_y = ps_y.reshape(numobs, *k, order = 'C').sum(idx_to_sum)

        temp_class = np.argmax(psl_y, axis = 1)
        try:
            new_sil = silhouette_score(dm, temp_class, metric = 'precomputed')
        except ValueError:
            # Presumably raised when a degenerate labelling (e.g. a single
            # cluster) is predicted; scored as the worst silhouette
            new_sil = -1

        # Record the silhouette of the *current* iteration (aligned with `likelihood`)
        silhouette.append(new_sil)

        # Store the params according to the silhouette or likelihood
        is_better = (best_sil < new_sil) if use_silhouette else (best_lik < new_lik)

        if is_better:
            z = (ps_y[..., n_axis] * Ez_ys[clustering_layer]).sum(1)
            best_sil = deepcopy(new_sil)
            classes = deepcopy(temp_class)

            # Store the output
            out['classes'] = deepcopy(classes)
            out['best_z'] = deepcopy(z_s[0])
            out['Ez.y'] = z
            out['best_k'] = deepcopy(k)
            out['best_r'] = deepcopy(r)

            out['best_w_s'] = deepcopy(w_s)
            out['lambda_bin'] = deepcopy(lambda_bin)
            out['lambda_ord'] = deepcopy(lambda_ord)
            out['lambda_categ'] = deepcopy(lambda_categ)
            out['lambda_cont'] = deepcopy(lambda_cont)

            out['eta'] = deepcopy(eta)
            out['mu'] = deepcopy(mu_s)
            out['sigma'] = deepcopy(sigma_s)

            out['psl_y'] = deepcopy(psl_y)
            out['ps_y'] = deepcopy(ps_y)

        # Keep track of the best likelihood reached so far
        # (it is the selection criterion when use_silhouette is False)
        if best_lik < new_lik:
            best_lik = deepcopy(new_lik)

        if prev_lik < new_lik:
            patience = 0
            M = M_growth(it_num + 2, r, numobs)
        else:
            patience += 1

        ###########################################################################
        ######################## Parameter selection  #############################
        ###########################################################################
        min_nb_clusters = 2

        if isnumeric(n_clusters): # To change when add multi mode
            is_not_min_specif = not(np.all(np.array(k) == n_clusters) & np.array_equal(r, [2,1]))
        else:
            is_not_min_specif = not(np.all(np.array(k) == min_nb_clusters) & np.array_equal(r, [2,1]))

        is_looking_for_better_arch = look_for_simpler_network(it_num) & perform_selec & is_not_min_specif
        if is_looking_for_better_arch:
            r_to_keep = r_select(y_bin, y_ord, y_categ, y_cont, zl1_ys, z2_z1s, w_s)

            # If r_l == 0, delete the last l + 1: layers
            new_L = np.sum([len(rl) != 0 for rl in r_to_keep]) - 1 

            k_to_keep = k_select(w_s, k, new_L, clustering_layer, not(isnumeric(n_clusters)))

            is_L_unchanged = (L == new_L)
            is_r_unchanged = np.all([len(r_to_keep[l]) == r[l] for l in range(new_L + 1)])
            is_k_unchanged = np.all([len(k_to_keep[l]) == k[l] for l in range(new_L)])

            is_selection = not(is_r_unchanged & is_k_unchanged & is_L_unchanged)

            assert new_L > 0

            if is_selection:

                eta = [eta[l][k_to_keep[l]] for l in range(new_L)]
                eta = [eta[l][:, r_to_keep[l]] for l in range(new_L)]

                H = [H[l][k_to_keep[l]] for l in range(new_L)]
                H = [H[l][:, r_to_keep[l]] for l in range(new_L)]
                H = [H[l][:, :, r_to_keep[l + 1]] for l in range(new_L)]

                psi = [psi[l][k_to_keep[l]] for l in range(new_L)]
                psi = [psi[l][:, r_to_keep[l]] for l in range(new_L)]
                psi = [psi[l][:, :, r_to_keep[l]] for l in range(new_L)]

                if nb_bin > 0:
                    # Add the intercept:
                    bin_r_to_keep = np.concatenate([[0], np.array(r_to_keep[0]) + 1]) 
                    lambda_bin = lambda_bin[:, bin_r_to_keep]

                if nb_ord > 0:
                    # Intercept coefficients handling is a little more complicated here
                    lambda_ord_intercept = [lambda_ord_j[:-r[0]] for lambda_ord_j in lambda_ord]
                    Lambda_ord_var = np.stack([lambda_ord_j[-r[0]:] for lambda_ord_j in lambda_ord])
                    Lambda_ord_var = Lambda_ord_var[:, r_to_keep[0]]
                    lambda_ord = [np.concatenate([lambda_ord_intercept[j], Lambda_ord_var[j]])\
                                  for j in range(nb_ord)]

                # To recheck
                if nb_cont > 0:
                    # Add the intercept:
                    cont_r_to_keep = np.concatenate([[0], np.array(r_to_keep[0]) + 1]) 
                    lambda_cont = lambda_cont[:, cont_r_to_keep]

                if nb_categ > 0:
                    lambda_categ_intercept = [lambda_categ[j][:, 0]  for j in range(nb_categ)]
                    # Latent-variable coefficients: the last r[0] columns
                    Lambda_categ_var = [lambda_categ_j[:,-r[0]:] for lambda_categ_j in lambda_categ]
                    # r_to_keep[0] indexes the latent block, not the full matrix
                    # (which also holds the intercept), hence select from Lambda_categ_var
                    Lambda_categ_var = [var_j[:, r_to_keep[0]] for var_j in Lambda_categ_var]

                    lambda_categ = [np.hstack([lambda_categ_intercept[j][..., n_axis], Lambda_categ_var[j]])\
                                   for j in range(nb_categ)]

                w = w_s.reshape(*k, order = 'C')
                new_k_idx_grid = np.ix_(*k_to_keep[:new_L])

                # If layer deletion, sum the last components of the paths
                if L > new_L: 
                    deleted_dims = tuple(range(L)[new_L:])
                    w_s = w[new_k_idx_grid].sum(deleted_dims).flatten(order = 'C')
                else:
                    w_s = w[new_k_idx_grid].flatten(order = 'C')

                w_s /= w_s.sum()

                k = [len(k_to_keep[l]) for l in range(new_L)]
                r = [len(r_to_keep[l]) for l in range(new_L + 1)]

                k_aug = k + [1]
                S = np.array([np.prod(k_aug[l:]) for l in range(new_L + 1)])
                L = new_L

                patience = 0

                # Identifiability conditions
                H = diagonal_cond(H, psi)
                Ez, AT = compute_z_moments(w_s, eta, H, psi)
                eta, H, psi = identifiable_estim_DGMM(eta, H, psi, Ez, AT)

                del(Ez)

            print('New architecture:')
            print('k', k)
            print('r', r)
            print('L', L)
            print('S',S)
            print("w_s", len(w_s))

        prev_lik = deepcopy(new_lik)
        it_num = it_num + 1
        print(likelihood)
        print(silhouette)

    out['likelihood'] = likelihood
    out['silhouette'] = silhouette

    return(out)