def process_palm_on_top_of_kmeans(kmeans_centroids):
    """
    Factorize a matrix of K-means centroids into sparse factors with PALM4MSA.

    The constraint sets are built from the parameters held in ``paraman``;
    ``paraman["--hierarchical"]`` selects between ``hierarchical_palm4msa``
    and plain ``palm4msa``.

    :param kmeans_centroids: The centroid matrix to factorize; ``shape[0]`` and ``shape[1]`` are used as the left and right dimensions of the constraint sets.
    :return: A ``SparseFactors`` object made of factors 1 and onward of the PALM4MSA result, with the scaling lambda folded into factor 1.
    """
    nb_centroids = kmeans_centroids.shape[0]
    # One factor more than requested is fitted; factor 0 of the result is
    # left out when the returned operator is assembled below.
    total_nb_factors = paraman["--nb-factors"] + 1

    constraint_sets, _constraint_sets_desc = build_constraint_set_smart(
        left_dim=nb_centroids,
        right_dim=kmeans_centroids.shape[1],
        nb_factors=total_nb_factors,
        sparsity_factor=paraman["--sparsity-factor"],
        residual_on_right=paraman["--residual-on-right"],
        fast_unstable_proj=True)

    initial_factors = init_lst_factors(*kmeans_centroids.shape,
                                       total_nb_factors)

    # sqrt(n) is the Frobenius norm of the n x n identity; lambda is scaled
    # by it before PALM4MSA and unscaled afterwards.
    identity_norm = np.sqrt(nb_centroids)
    use_hierarchical = paraman["--hierarchical"]

    palm_target = np.eye(nb_centroids) @ kmeans_centroids
    lambda_init = 1. * identity_norm

    if not use_hierarchical:
        # Plain PALM4MSA only uses the "finetune" projections of the last
        # constraint set.
        scaled_lambda, palm_factors, _, _, _ = palm4msa(
            arr_X_target=palm_target,
            lst_S_init=initial_factors,
            nb_factors=len(initial_factors),
            lst_projection_functions=constraint_sets[-1]["finetune"],
            f_lambda_init=lambda_init,
            nb_iter=paraman["--nb-iteration-palm"],
            update_right_to_left=True,
            delta_objective_error_threshold=paraman["--delta-threshold"],
            track_objective=False)
    else:
        scaled_lambda, palm_factors, _, _, _ = hierarchical_palm4msa(
            arr_X_target=palm_target,
            lst_S_init=initial_factors,
            lst_dct_projection_function=constraint_sets,
            f_lambda_init=lambda_init,
            nb_iter=paraman["--nb-iteration-palm"],
            update_right_to_left=True,
            residual_on_right=paraman["--residual-on-right"],
            delta_objective_error_threshold_palm=paraman["--delta-threshold"],
            track_objective_palm=False)

    log_memory_usage(
        "Memory after palm on top of kmeans in process_palm_on_top_of_kmeans")

    final_lambda = scaled_lambda / identity_norm
    fitted_factors = palm_factors.get_list_of_factors()
    # Fold lambda into factor 1 and keep the remaining factors untouched.
    scaled_first = fitted_factors[1] * final_lambda
    return SparseFactors([scaled_first] + fitted_factors[2:])
Beispiel #2
0
def qmeans(X_data: np.ndarray,
           K_nb_cluster: int,
           nb_iter: int,
           nb_factors: int,
           params_palm4msa: dict,
           initialization: np.ndarray,
           hierarchical_inside=False,
           delta_objective_error_threshold=1e-6,
           hierarchical_init=False):
    """
    Run K-means while keeping the centroid matrix as a product of sparse factors.

    After an initial PALM4MSA fit of ``initialization``, each iteration assigns
    the points to the current factorized centroids, updates the dense centroid
    estimate and refits the sparse factors with PALM4MSA.

    :param X_data: The data matrix of n examples in dimensions d in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iteration.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids not yet factorized.
    :param params_palm4msa: The dictionary of parameters for the palm4msa algorithm. The keys read are
        "init_lambda", "nb_iter", "lst_constraint_sets", "residual_on_right",
        "delta_objective_error_threshold" and "track_objective".
    :param hierarchical_inside: Tell the algorithm if the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold: The iterations stop as soon as the relative change of the
        objective between two consecutive iterations is not greater than this value.
    :param hierarchical_init: Tells if the algorithm should make the initialization of sparse factors with the hierarchical version of palm or not.
    :return: The tuple ``(objective_function, op_centroids, indicator_vector, lst_all_objective_functions_palm)``:
        the objective value of each iteration actually run, the final ``SparseFactors`` centroids,
        the cluster assignment of the last iteration (``None`` if no iteration was run) and the list
        of PALM4MSA objectives (initial fit first, then one entry per iteration).
    """
    assert K_nb_cluster == initialization.shape[0], "The number of cluster {} is not equal to the number of centroids in the initialization {}.".format(K_nb_cluster, initialization.shape[0])

    X_data_norms = get_squared_froebenius_norm_line_wise(X_data)

    nb_examples = X_data.shape[0]

    logger.info("Initializing Qmeans")

    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa["delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    X_centroids_hat = copy.deepcopy(initialization)

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1], nb_factors)

    # sqrt(K) is the Frobenius norm of the K x K identity; lambda is scaled by
    # it before the initial PALM4MSA fit and unscaled afterwards.
    eye_norm = np.sqrt(K_nb_cluster)

    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, _, objective_palm, _ = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        # The fifth returned value is deliberately discarded: binding it to
        # nb_iter_palm would overwrite the configured iteration budget that is
        # passed as nb_iter to every later PALM4MSA call.
        _lambda_tmp, op_factors, _, objective_palm, _ = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1][
                    "finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

    lst_factors = None  # safe assignment for debug

    _lambda = _lambda_tmp / eye_norm

    # -1 marks the iterations that were never run.
    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    # Defined up front so the return statement is valid when nb_iter == 0.
    indicator_vector = None

    i_iter = 0
    delta_objective_error = np.inf
    while (i_iter < nb_iter) and (delta_objective_error > delta_objective_error_threshold):

        logger.info("Iteration Qmeans {}".format(i_iter))

        # Centroid operator: factor 0 is left out and lambda is folded into
        # factor 1.
        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda] + lst_factors_[2:])

        ###########################
        # Cluster assignment step #
        ###########################

        indicator_vector, distances = assign_points_to_clusters(X_data, op_centroids, X_norms=X_data_norms)

        #######################
        # Cluster update step #
        #######################

        # get the number of observation in each cluster
        cluster_names, counts = np.unique(indicator_vector, return_counts=True)
        cluster_names_sorted = np.argsort(cluster_names)

        # Update centroid location using the newly assigned data point classes
        # (it happens in the update_clusters_with_integrity_check function),
        # check if all clusters still have points
        # and change the object X_centroids_hat in place if some cluster have lost points (biggest cluster)
        counts, cluster_names_sorted = update_clusters_with_integrity_check(X_data,
                                                                            X_data_norms,
                                                                            X_centroids_hat,  # in place changes
                                                                            K_nb_cluster,
                                                                            counts,
                                                                            indicator_vector,
                                                                            distances,
                                                                            cluster_names,
                                                                            cluster_names_sorted)

        #################
        # PALM4MSA step #
        #################

        # create the diagonal of the sqrt of those counts
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(counts[cluster_names_sorted] / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(counts[cluster_names_sorted])

        # set it as first factor
        op_factors.set_factor(0, diag_counts_sqrt_normalized)

        # Each centroid row is weighted by the sqrt of its cluster size.
        arr_X_target = diag_counts_sqrt[:, None] * X_centroids_hat

        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, _ = \
                hierarchical_palm4msa(
                    arr_X_target=arr_X_target,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)

        else:
            # Fifth returned value discarded on purpose (see the initial fit).
            _lambda_tmp, op_factors, _, objective_palm, _ = \
                palm4msa(arr_X_target=arr_X_target,
                         lst_S_init=op_factors.get_list_of_factors(),
                         nb_factors=op_factors.n_factors,
                         lst_projection_functions=lst_proj_op_by_fac_step[-1][
                             "finetune"],
                         f_lambda_init=_lambda * np.sqrt(nb_examples),
                         nb_iter=nb_iter_palm,
                         update_right_to_left=True,
                         track_objective=track_objective_palm,
                         delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        lst_all_objective_functions_palm.append(objective_palm)

        # Undo the sqrt(nb_examples) scaling applied to f_lambda_init above.
        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        # The objective is evaluated with the centroids used for the
        # assignment of this iteration, i.e. before the PALM4MSA refit.
        objective_function[i_iter] = compute_objective(X_data, op_centroids, indicator_vector)
        if i_iter >= 1:
            delta_objective_error = np.abs(objective_function[i_iter] - objective_function[i_iter-1]) / objective_function[i_iter-1]

        # todo: check that the absolute error stays below the threshold several times in a row

        i_iter += 1

    # Rebuild the centroid operator from the factors of the last PALM4MSA run.
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] + lst_factors_[2:])

    return objective_function[:i_iter], op_centroids, indicator_vector, lst_all_objective_functions_palm
            sparsity_factor=sparsity_fac,
            residual_on_right=True,
            fast_unstable_proj=False,
            constant_first=False)
        lst_constraints_palm = lst_constraints[-1]["finetune"]

        # construit la matrice cible
        X_target = np.random.rand(dim[0], dim[1])

        # op_factor est en quelque sortes la liste des facteurs sparses
        _lambda, op_factors, _, _, _ = \
            palm4msa(
                arr_X_target=X_target,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_constraints_palm,
                f_lambda_init=1.,
                nb_iter=nb_iter_palm,
                update_right_to_left=update_right_to_left,
                track_objective=True,
                delta_objective_error_threshold=delta_objective)

        pair["input"] = X_target
        pair["output"] = _lambda * op_factors.compute_product(
            return_array=True)

        results["nfac_{}_in_{}_out_{}".format(n_fac, dim[0], dim[1])] = pair

# Directory "examples_jovial" located under the script path stripped of its
# extension. with_suffix("") removes only the final extension, so a dot
# elsewhere in the path (e.g. "./script.py" or "my.project/script.py") does
# not truncate the path the way splitting on the first "." did.
path_dir = Path(__file__).with_suffix("") / "examples_jovial"
path_dir.mkdir(parents=True, exist_ok=True)

for name_xp, dct_xp in results.items():
Beispiel #4
0
def qkmeans_minibatch(X_data: np.ndarray,
                      K_nb_cluster: int,
                      nb_iter: int,
                      nb_factors: int,
                      params_palm4msa: dict,
                      initialization: np.ndarray,
                      batch_size: int,
                      hierarchical_inside=False,
                      delta_objective_error_threshold=1e-6,
                      hierarchical_init=False):
    """
    Mini-batch variant of K-means with centroids kept as a product of sparse factors.

    After an initial PALM4MSA fit of ``initialization``, each iteration goes
    through the data batch by batch to assign points and accumulate the dense
    centroid estimate, then refits the sparse factors with PALM4MSA.

    :param X_data: The data matrix of n examples in dimensions d in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iteration.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids not yet factorized.
    :param params_palm4msa: The dictionary of parameters for the palm4msa algorithm. The keys read are
        "init_lambda", "nb_iter", "lst_constraint_sets", "residual_on_right",
        "delta_objective_error_threshold" and "track_objective".
    :param hierarchical_inside: Tell the algorithm if the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold: The iterations stop as soon as the relative change of the
        objective between two consecutive iterations is not greater than this value.
    :param hierarchical_init: Tells if the algorithm should make the initialization of sparse factors with the hierarchical version of palm or not.
    :param batch_size:  The size of each batch.

    :return: The tuple ``(objective_function, op_centroids, full_indicator_vector, lst_all_objective_functions_palm)``:
        the objective value of each iteration actually run, the final ``SparseFactors`` centroids,
        the cluster assignment of the last iteration (``None`` if no iteration was run) and the list
        of PALM4MSA objectives (initial fit first, then one entry per iteration).
    """

    assert K_nb_cluster == initialization.shape[0]

    logger.debug("Compute squared froebenius norm of data")
    X_data_norms = get_squared_froebenius_norm_line_wise_batch_by_batch(
        X_data, batch_size)

    nb_examples = X_data.shape[0]
    # Only used in the log messages below.
    total_nb_of_minibatch = X_data.shape[0] // batch_size

    X_centroids_hat = copy.deepcopy(initialization)

    # ################################ INIT PALM4MSA ###############################
    logger.info("Initializing QKmeans with PALM algorithm")

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1],
                                   nb_factors)
    # sqrt(K) is the Frobenius norm of the K x K identity; lambda is scaled by
    # it before the initial PALM4MSA fit and unscaled afterwards.
    eye_norm = np.sqrt(K_nb_cluster)

    ##########################
    # GET PARAMS OF PALM4MSA #
    ##########################
    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa[
        "delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    ####################
    # INIT RUN OF PALM #
    ####################

    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, _, objective_palm, _ = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        # The fifth returned value is deliberately discarded: binding it to
        # nb_iter_palm would overwrite the configured iteration budget that is
        # passed as nb_iter to every later PALM4MSA call.
        _lambda_tmp, op_factors, _, objective_palm, _ = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1][
                    "finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

    # ################################################################

    lst_factors = None  # safe assignment for debug

    _lambda = _lambda_tmp / eye_norm

    # -1 marks the iterations that were never run.
    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    # Defined up front so the return statement is valid when nb_iter == 0.
    full_indicator_vector = None

    i_iter = 0
    delta_objective_error = np.inf
    while ((i_iter < nb_iter)
           and (delta_objective_error > delta_objective_error_threshold)):
        logger.info("Iteration number {}/{}".format(i_iter, nb_iter))

        # Centroid operator for this iteration: factor 0 is left out and
        # lambda is folded into factor 1.
        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                     lst_factors_[2:])

        # Prepare next epoch
        full_count_vector = np.zeros(K_nb_cluster, dtype=int)
        full_indicator_vector = np.zeros(X_data.shape[0], dtype=int)

        X_centroids_hat = np.zeros_like(X_centroids_hat)

        for i_minibatch, example_batch_indexes in enumerate(
                DataGenerator(X_data,
                              batch_size=batch_size,
                              return_indexes=True)):
            logger.info(
                "Minibatch number {}/{}; Iteration number {}/{}".format(
                    i_minibatch, total_nb_of_minibatch, i_iter, nb_iter))
            example_batch = X_data[example_batch_indexes]
            example_batch_norms = X_data_norms[example_batch_indexes]

            ##########################
            # Update centroid oracle #
            ##########################

            indicator_vector, distances = assign_points_to_clusters(
                example_batch, op_centroids, X_norms=example_batch_norms)
            full_indicator_vector[example_batch_indexes] = indicator_vector

            # Per-batch cluster sizes; clusters absent from the batch keep 0.
            cluster_names, counts = np.unique(indicator_vector,
                                              return_counts=True)
            count_vector = np.zeros(K_nb_cluster)
            count_vector[cluster_names] = counts

            full_count_vector = update_clusters(example_batch, X_centroids_hat,
                                                K_nb_cluster,
                                                full_count_vector,
                                                count_vector, indicator_vector)

        # The objective is evaluated with the centroids used for the
        # assignment of this iteration, i.e. before the PALM4MSA refit.
        objective_function[i_iter] = compute_objective_by_batch(
            X_data, op_centroids, full_indicator_vector, batch_size)

        # inplace modification of X_centroids_hat and full_count_vector and full_indicator_vector
        check_cluster_integrity(X_data, X_centroids_hat, K_nb_cluster,
                                full_count_vector, full_indicator_vector)

        #########################
        # Do palm for iteration #
        #########################

        # create the diagonal of the sqrt of those counts
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(full_count_vector / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(full_count_vector)

        # set it as first factor
        op_factors.set_factor(0, diag_counts_sqrt_normalized)

        # Each centroid row is weighted by the sqrt of its cluster size.
        arr_X_target = diag_counts_sqrt[:, None] * X_centroids_hat

        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, _ = \
                hierarchical_palm4msa(
                    arr_X_target=arr_X_target,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)

        else:
            # Fifth returned value discarded on purpose (see the initial fit).
            _lambda_tmp, op_factors, _, objective_palm, _ = \
                palm4msa(arr_X_target=arr_X_target,
                         lst_S_init=op_factors.get_list_of_factors(),
                         nb_factors=op_factors.n_factors,
                         lst_projection_functions=lst_proj_op_by_fac_step[-1][
                             "finetune"],
                         f_lambda_init=_lambda * np.sqrt(nb_examples),
                         nb_iter=nb_iter_palm,
                         update_right_to_left=True,
                         track_objective=track_objective_palm,
                         delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        # Undo the sqrt(nb_examples) scaling applied to f_lambda_init above.
        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        ############################

        lst_all_objective_functions_palm.append(objective_palm)

        if i_iter >= 1:
            delta_objective_error = np.abs(objective_function[i_iter] -
                                           objective_function[i_iter - 1]
                                           ) / objective_function[i_iter - 1]

        # todo: check that the absolute error stays below the threshold several times in a row

        i_iter += 1

    # Rebuild the centroid operator from the factors of the last PALM4MSA run.
    # The lst_factors_ bound inside the loop predates that run, and does not
    # exist at all when no iteration was executed.
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                 lst_factors_[2:])

    return (objective_function[:i_iter], op_centroids, full_indicator_vector,
            lst_all_objective_functions_palm)