# Module-level imports assumed by the excerpts below; project-internal helpers
# (palm4msa, hierarchical_palm4msa, SparseFactors, build_constraint_set_smart,
# init_lst_factors, paraman, logger, ...) come from the surrounding package.
import copy
from pathlib import Path

import numpy as np
from scipy.sparse import csr_matrix


def process_palm_on_top_of_kmeans(kmeans_centroids):
    # Build one set of sparsity constraints per factorization step.
    lst_constraint_sets, lst_constraint_sets_desc = build_constraint_set_smart(
        left_dim=kmeans_centroids.shape[0],
        right_dim=kmeans_centroids.shape[1],
        nb_factors=paraman["--nb-factors"] + 1,
        sparsity_factor=paraman["--sparsity-factor"],
        residual_on_right=paraman["--residual-on-right"],
        fast_unstable_proj=True)

    lst_factors = init_lst_factors(*kmeans_centroids.shape,
                                   paraman["--nb-factors"] + 1)

    eye_norm = np.sqrt(kmeans_centroids.shape[0])

    if paraman["--hierarchical"]:
        _lambda_tmp, op_factors, U_centroids, nb_iter_by_factor, objective_palm = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(kmeans_centroids.shape[0]) @ kmeans_centroids,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_constraint_sets,
                f_lambda_init=1. * eye_norm,
                nb_iter=paraman["--nb-iteration-palm"],
                update_right_to_left=True,
                residual_on_right=paraman["--residual-on-right"],
                delta_objective_error_threshold_palm=paraman["--delta-threshold"],
                track_objective_palm=False)
    else:
        _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
            palm4msa(
                arr_X_target=np.eye(kmeans_centroids.shape[0]) @ kmeans_centroids,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_constraint_sets[-1]["finetune"],
                f_lambda_init=1. * eye_norm,
                nb_iter=paraman["--nb-iteration-palm"],
                update_right_to_left=True,
                delta_objective_error_threshold=paraman["--delta-threshold"],
                track_objective=False)

    log_memory_usage(
        "Memory after palm on top of kmeans in process_palm_on_top_of_kmeans")

    # Undo the normalization by the identity factor norm, then drop the first
    # (identity) factor and fold lambda into the next one.
    _lambda = _lambda_tmp / eye_norm
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda]
                                 + lst_factors_[2:])

    return op_centroids
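# Illustrative usage sketch (not part of the original module): the returned
# SparseFactors operator can stand in for the dense centroid matrix. Assumes
# `paraman` has already been populated from the command line; the argument
# `kmeans_centroids` is the (K, d) array produced by a plain k-means run.
def _example_process_palm_on_top_of_kmeans(kmeans_centroids):
    op_centroids = process_palm_on_top_of_kmeans(kmeans_centroids)
    # Materialize the factored operator back into a dense (K, d) array to
    # measure how well it approximates the input centroids.
    dense_centroids = op_centroids.compute_product(return_array=True)
    approximation_error = np.linalg.norm(dense_centroids - kmeans_centroids)
    return op_centroids, approximation_error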
def qmeans(X_data: np.ndarray,
           K_nb_cluster: int,
           nb_iter: int,
           nb_factors: int,
           params_palm4msa: dict,
           initialization: np.ndarray,
           hierarchical_inside=False,
           delta_objective_error_threshold=1e-6,
           hierarchical_init=False):
    """
    :param X_data: The data matrix of n examples in dimension d, in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iterations.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids, not yet factorized.
    :param params_palm4msa: The dictionary of parameters for the palm4msa algorithm.
    :param hierarchical_inside: Tells the algorithm whether the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold: Threshold on the relative objective decrease under which the main loop stops.
    :param hierarchical_init: Tells whether the algorithm should initialize the sparse factors with the hierarchical version of palm or not.
    :return:
    """
    assert K_nb_cluster == initialization.shape[0], \
        "The number of clusters {} is not equal to the number of centroids in the initialization {}.".format(
            K_nb_cluster, initialization.shape[0])

    X_data_norms = get_squared_froebenius_norm_line_wise(X_data)
    nb_examples = X_data.shape[0]

    logger.info("Initializing Qmeans")

    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa["delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    X_centroids_hat = copy.deepcopy(initialization)

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1], nb_factors)

    eye_norm = np.sqrt(K_nb_cluster)

    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, U_centroids, objective_palm, array_objective_hierarchical = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        _lambda_tmp, op_factors, U_centroids, objective_palm, nb_iter_palm = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1]["finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

    lst_factors = None  # safe assignment for debug

    _lambda = _lambda_tmp / eye_norm

    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    i_iter = 0
    delta_objective_error = np.inf
    while i_iter < nb_iter and delta_objective_error > delta_objective_error_threshold:
        logger.info("Iteration Qmeans {}".format(i_iter))

        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda]
                                     + lst_factors_[2:])

        ###########################
        # Cluster assignment step #
        ###########################
        indicator_vector, distances = assign_points_to_clusters(
            X_data, op_centroids, X_norms=X_data_norms)

        #######################
        # Cluster update step #
        #######################
        # Get the number of observations in each cluster.
        cluster_names, counts = np.unique(indicator_vector, return_counts=True)
        cluster_names_sorted = np.argsort(cluster_names)

        # Update centroid locations with the newly assigned data points (this
        # happens inside update_clusters_with_integrity_check), check that all
        # clusters still have points, and change X_centroids_hat in place if
        # some cluster has lost all its points (splitting the biggest cluster).
        counts, cluster_names_sorted = update_clusters_with_integrity_check(
            X_data,
            X_data_norms,
            X_centroids_hat,  # in-place changes
            K_nb_cluster,
            counts,
            indicator_vector,
            distances,
            cluster_names,
            cluster_names_sorted)

        #################
        # PALM4MSA step #
        #################
        # Create the diagonal matrix of the square roots of those counts.
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(counts[cluster_names_sorted] / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(counts[cluster_names_sorted])

        # Set it as the first factor.
        op_factors.set_factor(0, diag_counts_sqrt_normalized)

        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
                hierarchical_palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)
        else:
            _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
                palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    nb_factors=op_factors.n_factors,
                    lst_projection_functions=lst_proj_op_by_fac_step[-1]["finetune"],
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    track_objective=track_objective_palm,
                    delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        lst_all_objective_functions_palm.append(objective_palm)

        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        objective_function[i_iter] = compute_objective(X_data, op_centroids, indicator_vector)
        if i_iter >= 1:
            delta_objective_error = np.abs(
                objective_function[i_iter] - objective_function[i_iter - 1]) \
                / objective_function[i_iter - 1]
            # TODO: check that the absolute error stays below the threshold several times in a row.

        i_iter += 1

    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda]
                                 + lst_factors_[2:])

    return objective_function[:i_iter], op_centroids, indicator_vector, lst_all_objective_functions_palm
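# Illustrative usage sketch (not part of the original module): a minimal call
# to `qmeans`, spelling out the `params_palm4msa` keys read above. The
# constraint sets are assumed to come from `build_constraint_set_smart`, as in
# `process_palm_on_top_of_kmeans`; K, nb_factors and the iteration counts are
# hypothetical values.
def _example_qmeans_call(X_data, lst_constraint_sets):
    K, nb_factors = 16, 5
    params_palm4msa = {
        "init_lambda": 1.,
        "nb_iter": 300,
        "lst_constraint_sets": lst_constraint_sets,
        "residual_on_right": True,
        "delta_objective_error_threshold": 1e-6,
        "track_objective": False,
    }
    # Initialize centroids with K distinct rows drawn from the data.
    initialization = X_data[np.random.choice(X_data.shape[0], K, replace=False)].copy()
    objective_values, op_centroids, indicator_vector, _ = qmeans(
        X_data, K, nb_iter=20, nb_factors=nb_factors,
        params_palm4msa=params_palm4msa, initialization=initialization)
    return objective_values, op_centroids, indicator_vector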
        sparsity_factor=sparsity_fac,
        residual_on_right=True,
        fast_unstable_proj=False,
        constant_first=False)
    lst_constraints_palm = lst_constraints[-1]["finetune"]

    # Build the target matrix.
    X_target = np.random.rand(dim[0], dim[1])

    # op_factors is, in effect, the list of sparse factors.
    _lambda, op_factors, _, _, _ = \
        palm4msa(
            arr_X_target=X_target,
            lst_S_init=lst_factors,
            nb_factors=len(lst_factors),
            lst_projection_functions=lst_constraints_palm,
            f_lambda_init=1.,
            nb_iter=nb_iter_palm,
            update_right_to_left=update_right_to_left,
            track_objective=True,
            delta_objective_error_threshold=delta_objective)

    pair["input"] = X_target
    pair["output"] = _lambda * op_factors.compute_product(return_array=True)

    results["nfac_{}_in_{}_out_{}".format(n_fac, dim[0], dim[1])] = pair

path_dir = Path(__file__.split(".")[0]) / "examples_jovial"
path_dir.mkdir(parents=True, exist_ok=True)

for name_xp, dct_xp in results.items():
def qkmeans_minibatch(X_data: np.ndarray,
                      K_nb_cluster: int,
                      nb_iter: int,
                      nb_factors: int,
                      params_palm4msa: dict,
                      initialization: np.ndarray,
                      batch_size: int,
                      hierarchical_inside=False,
                      delta_objective_error_threshold=1e-6,
                      hierarchical_init=False):
    """
    :param X_data: The data matrix of n examples in dimension d, in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iterations.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids, not yet factorized.
    :param params_palm4msa: The dictionary of parameters for the palm4msa algorithm.
    :param hierarchical_inside: Tells the algorithm whether the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold: Threshold on the relative objective decrease under which the main loop stops.
    :param hierarchical_init: Tells whether the algorithm should initialize the sparse factors with the hierarchical version of palm or not.
    :param batch_size: The size of each batch.
    :return:
    """
    assert K_nb_cluster == initialization.shape[0]

    logger.debug("Compute squared Frobenius norm of data")
    X_data_norms = get_squared_froebenius_norm_line_wise_batch_by_batch(
        X_data, batch_size)
    nb_examples = X_data.shape[0]
    total_nb_of_minibatch = X_data.shape[0] // batch_size

    X_centroids_hat = copy.deepcopy(initialization)

    # ################################ INIT PALM4MSA ###############################
    logger.info("Initializing QKmeans with PALM algorithm")

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1], nb_factors)

    eye_norm = np.sqrt(K_nb_cluster)

    ##########################
    # GET PARAMS OF PALM4MSA #
    ##########################
    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa[
        "delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    ####################
    # INIT RUN OF PALM #
    ####################
    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1]["finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)
    # ##############################################################################

    lst_factors = None  # safe assignment for debug

    _lambda = _lambda_tmp / eye_norm

    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    i_iter = 0
    delta_objective_error = np.inf
    while i_iter < nb_iter and delta_objective_error > delta_objective_error_threshold:
        logger.info("Iteration number {}/{}".format(i_iter, nb_iter))

        # Re-init palm factors for the iteration.
        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda]
                                     + lst_factors_[2:])

        # Prepare the next epoch.
        full_count_vector = np.zeros(K_nb_cluster, dtype=int)
        full_indicator_vector = np.zeros(X_data.shape[0], dtype=int)
        X_centroids_hat = np.zeros_like(X_centroids_hat)

        for i_minibatch, example_batch_indexes in enumerate(
                DataGenerator(X_data, batch_size=batch_size, return_indexes=True)):
            logger.info(
                "Minibatch number {}/{}; Iteration number {}/{}".format(
                    i_minibatch, total_nb_of_minibatch, i_iter, nb_iter))
            example_batch = X_data[example_batch_indexes]
            example_batch_norms = X_data_norms[example_batch_indexes]

            ##########################
            # Update centroid oracle #
            ##########################
            indicator_vector, distances = assign_points_to_clusters(
                example_batch, op_centroids, X_norms=example_batch_norms)
            full_indicator_vector[example_batch_indexes] = indicator_vector

            cluster_names, counts = np.unique(indicator_vector, return_counts=True)
            count_vector = np.zeros(K_nb_cluster)
            count_vector[cluster_names] = counts

            full_count_vector = update_clusters(example_batch,
                                                X_centroids_hat,
                                                K_nb_cluster,
                                                full_count_vector,
                                                count_vector,
                                                indicator_vector)

        objective_function[i_iter] = compute_objective_by_batch(
            X_data, op_centroids, full_indicator_vector, batch_size)

        # In-place modification of X_centroids_hat, full_count_vector and full_indicator_vector.
        check_cluster_integrity(X_data,
                                X_centroids_hat,
                                K_nb_cluster,
                                full_count_vector,
                                full_indicator_vector)

        #########################
        # Do palm for iteration #
        #########################
        # Create the diagonal matrix of the square roots of those counts.
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(full_count_vector / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(full_count_vector)

        # Set it as the first factor.
        op_factors.set_factor(0, diag_counts_sqrt_normalized)

        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
                hierarchical_palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)
        else:
            _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
                palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    nb_factors=op_factors.n_factors,
                    lst_projection_functions=lst_proj_op_by_fac_step[-1]["finetune"],
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    track_objective=track_objective_palm,
                    delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        ############################

        lst_all_objective_functions_palm.append(objective_palm)

        if i_iter >= 1:
            delta_objective_error = np.abs(
                objective_function[i_iter] - objective_function[i_iter - 1]) \
                / objective_function[i_iter - 1]
            # TODO: check that the absolute error stays below the threshold several times in a row.

        i_iter += 1

    # Re-fetch the factors updated by the last palm call (mirrors qmeans above).
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda]
                                 + lst_factors_[2:])

    return objective_function[:i_iter], op_centroids, full_indicator_vector, lst_all_objective_functions_palm
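# Illustrative usage sketch (not part of the original module): same setup as
# the `qmeans` example above, but with the minibatch variant, which assigns
# points and accumulates centroid updates `batch_size` rows at a time. All
# numeric values are hypothetical.
def _example_qkmeans_minibatch_call(X_data, lst_constraint_sets):
    K, nb_factors = 16, 5
    params_palm4msa = {
        "init_lambda": 1.,
        "nb_iter": 300,
        "lst_constraint_sets": lst_constraint_sets,
        "residual_on_right": True,
        "delta_objective_error_threshold": 1e-6,
        "track_objective": False,
    }
    initialization = X_data[np.random.choice(X_data.shape[0], K, replace=False)].copy()
    objective_values, op_centroids, full_indicator_vector, _ = qkmeans_minibatch(
        X_data, K, nb_iter=20, nb_factors=nb_factors,
        params_palm4msa=params_palm4msa, initialization=initialization,
        batch_size=1000)
    return objective_values, op_centroids, full_indicator_vector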