Example #1
def special_rbf_kernel(X, Y, gamma, norm_X=None, norm_Y=None):
    """
    RBF kernel expressed in the factored form f(x) f(y) g(x y^T).

    Can handle X and Y as SparseFactors.

    :param X: n x d matrix
    :param Y: m x d matrix
    :param gamma: The bandwidth parameter of the RBF kernel.
    :param norm_X: Precomputed squared norms of the lines of X (computed if None).
    :param norm_Y: Precomputed squared norms of the lines of Y (computed if None).
    :return: The n x m kernel matrix between the lines of X and the lines of Y.
    """
    assert len(X.shape) == len(Y.shape) == 2

    if norm_X is None:
        norm_X = get_squared_froebenius_norm_line_wise(X)
    if norm_Y is None:
        norm_Y = get_squared_froebenius_norm_line_wise(Y)

    def f(norm_mat):
        return np.exp(-gamma * norm_mat)

    def g(scal):
        return np.exp(2 * gamma * scal)

    if isinstance(X, SparseFactors) and isinstance(Y, SparseFactors):
        # xyt = SparseFactors(X.get_list_of_factors() + Y.transpose().get_list_of_factors()).compute_product(return_array=True)
        S = SparseFactors(
            lst_factors=X.get_list_of_factors() + Y.get_list_of_factors_H(),
            lst_factors_H=X.get_list_of_factors_H() + Y.get_list_of_factors())
        xyt = S.compute_product(return_array=True)
    else:
        xyt = X @ Y.transpose()

    return f(norm_X).reshape(-1, 1) * g(xyt) * f(norm_Y).reshape(1, -1)
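
Note: the identity behind this factorization is exp(-gamma * ||x - y||^2) = exp(-gamma * ||x||^2) * exp(2 * gamma * <x, y>) * exp(-gamma * ||y||^2). A minimal NumPy check of the identity itself, independent of SparseFactors:

import numpy as np

gamma = 0.5
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
Y = rng.standard_normal((4, 3))

# squared norms of the lines of X and Y
norm_X = (X ** 2).sum(axis=1)
norm_Y = (Y ** 2).sum(axis=1)

# direct pairwise RBF kernel
direct = np.exp(-gamma * ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=2))

# factored form f(norm_X) * g(X @ Y.T) * f(norm_Y), as in special_rbf_kernel
factored = (np.exp(-gamma * norm_X).reshape(-1, 1)
            * np.exp(2 * gamma * (X @ Y.T))
            * np.exp(-gamma * norm_Y).reshape(1, -1))

assert np.allclose(direct, factored)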
Example #2
def special_rbf_kernel(X, Y, gamma, norm_X=None, norm_Y=None, exp_outside=True):
    """
    Rbf kernel expressed under the form f(x)f(u)f(xy^T)

    Can handle X and Y as Sparse Factors.

    :param X: n x d matrix
    :param Y: n x d matrix
    :param gamma:
    :param norm_X: nx1 matrix
    :param norm_Y: 1xn matrix
    :param exp_outside: Tells if the exponential should be computed just once. Numerical instability may arise if False.
    :return:
    """
    assert len(X.shape) == len(Y.shape) == 2

    if norm_X is None:
        norm_X = row_norms(X, squared=True)[:, np.newaxis]
    else:
        norm_X = check_array(norm_X)
        assert norm_X.shape[0] == X.shape[0], "number of lines in X and norm_X should be the same"
        assert norm_X.shape[1] == 1, "norm_X should be a column vector"

    if norm_Y is None:
        norm_Y = row_norms(Y, squared=True)[np.newaxis, :]
    else:
        norm_Y = check_array(norm_Y)
        assert norm_Y.shape[1] == Y.shape[0], "number of lines in Y and norm_Y should be the same"
        assert norm_Y.shape[0] == 1, "norm_Y should be a row vector"

    def f(norm_mat):
        return np.exp(-gamma * norm_mat)

    def g(scal):
        return np.exp(2 * gamma * scal)

    if isinstance(X, SparseFactors) and isinstance(Y, SparseFactors):
        # xyt = SparseFactors(X.get_list_of_factors() + Y.transpose().get_list_of_factors()).compute_product(return_array=True)
        S = SparseFactors(
            lst_factors=X.get_list_of_factors() + Y.get_list_of_factors_H(),
            lst_factors_H=X.get_list_of_factors_H() + Y.get_list_of_factors())
        xyt = S.compute_product(return_array=True)
    elif not isinstance(X, SparseFactors) and isinstance(Y, SparseFactors):
        xyt = (Y @ X.transpose()).transpose()
    else:
        xyt = X @ Y.transpose()

    if not exp_outside:
        return f(norm_X) * g(xyt) * f(norm_Y)
    else:
        distance = -2 * xyt
        distance += norm_X
        distance += norm_Y
        # distance = norm_X + norm_Y - (2 * xyt)
        np.maximum(distance, 0, out=distance)
        if X is Y:
            np.fill_diagonal(distance, 0)

        in_exp = -gamma * distance
        return np.exp(in_exp)
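
Note: the exp_outside=True branch exists because the three-exponential form can underflow and overflow for large gamma * norm, even when the kernel value itself is well within range. A small sketch of the failure mode:

import numpy as np

gamma = 1.0
x = np.array([[30.0]])
y = np.array([[30.0]])  # identical points: the kernel value should be exactly 1
norm = 900.0            # ||x||^2 == ||y||^2

# three-exponential form: exp(-900) underflows to 0, exp(1800) overflows to inf -> nan
unstable = np.exp(-gamma * norm) * np.exp(2 * gamma * (x @ y.T)) * np.exp(-gamma * norm)

# single exponential on the assembled squared distance: exact
distance = norm + norm - 2 * (x @ y.T)
stable = np.exp(-gamma * np.maximum(distance, 0))

print(unstable, stable)  # nan vs. 1.0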
Example #3
def process_palm_on_top_of_kmeans(kmeans_centroids):
    lst_constraint_sets, lst_constraint_sets_desc = build_constraint_set_smart(
        left_dim=kmeans_centroids.shape[0],
        right_dim=kmeans_centroids.shape[1],
        nb_factors=paraman["--nb-factors"] + 1,
        sparsity_factor=paraman["--sparsity-factor"],
        residual_on_right=paraman["--residual-on-right"])

    lst_factors = init_lst_factors(*kmeans_centroids.shape,
                                   paraman["--nb-factors"] + 1)

    eye_norm = np.sqrt(kmeans_centroids.shape[0])

    _lambda_tmp, op_factors, U_centroids, nb_iter_by_factor, objective_palm = \
        hierarchical_palm4msa(
            arr_X_target=np.eye(kmeans_centroids.shape[0]) @ kmeans_centroids,
            lst_S_init=lst_factors,
            lst_dct_projection_function=lst_constraint_sets,
            f_lambda_init=1. * eye_norm,
            nb_iter=paraman["--nb-iteration-palm"],
            update_right_to_left=True,
            residual_on_right=paraman["--residual-on-right"],
            graphical_display=False)

    _lambda = _lambda_tmp / eye_norm
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                 lst_factors_[2:])

    return op_centroids
Example #4
def process_palm_on_top_of_kmeans(kmeans_centroids):
    lst_constraint_sets, lst_constraint_sets_desc = build_constraint_set_smart(
        left_dim=kmeans_centroids.shape[0],
        right_dim=kmeans_centroids.shape[1],
        nb_factors=paraman["--nb-factors"] + 1,
        sparsity_factor=paraman["--sparsity-factor"],
        residual_on_right=paraman["--residual-on-right"],
        fast_unstable_proj=True)

    lst_factors = init_lst_factors(*kmeans_centroids.shape,
                                   paraman["--nb-factors"] + 1)

    eye_norm = np.sqrt(kmeans_centroids.shape[0])

    if paraman["--hierarchical"]:
        _lambda_tmp, op_factors, U_centroids, nb_iter_by_factor, objective_palm = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(kmeans_centroids.shape[0]) @ kmeans_centroids,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_constraint_sets,
                f_lambda_init=1. * eye_norm,
                nb_iter=paraman["--nb-iteration-palm"],
                update_right_to_left=True,
                residual_on_right=paraman["--residual-on-right"],
                delta_objective_error_threshold_palm=paraman["--delta-threshold"],
                track_objective_palm=False)
    else:
        _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
            palm4msa(arr_X_target=np.eye(kmeans_centroids.shape[0]) @ kmeans_centroids,
                     lst_S_init=lst_factors,
                     nb_factors=len(lst_factors),
                     lst_projection_functions=lst_constraint_sets[-1]["finetune"],
                     f_lambda_init=1. * eye_norm,
                     nb_iter=paraman["--nb-iteration-palm"],
                     update_right_to_left=True,
                     delta_objective_error_threshold=paraman["--delta-threshold"],
                     track_objective=False)

    log_memory_usage(
        "Memory after palm on top of kmeans in process_palm_on_top_of_kmeans")

    _lambda = _lambda_tmp / eye_norm
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                 lst_factors_[2:])

    return op_centroids
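
Note: both variants above read their hyperparameters from the module-level paraman mapping. A hypothetical minimal configuration (the keys are the ones referenced in the code; the values are only illustrative):

paraman = {
    "--nb-factors": 4,
    "--sparsity-factor": 2,
    "--residual-on-right": True,
    "--nb-iteration-palm": 300,
    "--hierarchical": True,      # only read by the second variant
    "--delta-threshold": 1e-6,   # only read by the second variant
}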
Example #5
def qmeans(X_data: np.ndarray,
           K_nb_cluster: int,
           nb_iter: int,
           nb_factors: int,
           params_palm4msa: dict,
           initialization: np.ndarray,
           hierarchical_inside=False,
           delta_objective_error_threshold=1e-6,
           hierarchical_init=False):
    """

    :param X_data: The data matrix of n examples in dimensions d in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iteration.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids not yet factorized.
    :param params_palm4msa: The dictionnary of parameters for the palm4msa algorithm.
    :param hierarchical_inside: Tell the algorithm if the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold:
    :param hierarchical_init: Tells if the algorithm should make the initialization of sparse factors with the hierarchical version of palm or not.
    :return:
    """
    assert K_nb_cluster == initialization.shape[0], "The number of clusters {} is not equal to the number of centroids in the initialization {}.".format(K_nb_cluster, initialization.shape[0])

    X_data_norms = get_squared_froebenius_norm_line_wise(X_data)

    nb_examples = X_data.shape[0]

    logger.info("Initializing Qmeans")

    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa["delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    X_centroids_hat = copy.deepcopy(initialization)

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1], nb_factors)

    eye_norm = np.sqrt(K_nb_cluster)

    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, U_centroids, objective_palm, array_objective_hierarchical = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        _lambda_tmp, op_factors, U_centroids, objective_palm, nb_iter_palm = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1][
                    "finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

    lst_factors = None  # drop the reference so the stale initial factors cannot be reused by mistake

    _lambda = _lambda_tmp / eye_norm

    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    i_iter = 0
    delta_objective_error = np.inf
    while ((i_iter < nb_iter) and (delta_objective_error > delta_objective_error_threshold)):

        logger.info("Iteration Qmeans {}".format(i_iter))

        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda] + lst_factors_[2:])

        ###########################
        # Cluster assignment step #
        ###########################

        indicator_vector, distances = assign_points_to_clusters(X_data, op_centroids, X_norms=X_data_norms)

        #######################
        # Cluster update step #
        #######################

        # get the number of observations in each cluster
        cluster_names, counts = np.unique(indicator_vector, return_counts=True)
        cluster_names_sorted = np.argsort(cluster_names)

        # Update centroid locations using the newly assigned data point classes
        # (this happens inside update_clusters_with_integrity_check),
        # check that all clusters still have points,
        # and change the object X_centroids_hat in place if some clusters
        # have lost their points (using the biggest cluster)
        counts, cluster_names_sorted = update_clusters_with_integrity_check(X_data,
                                                                            X_data_norms,
                                                                            X_centroids_hat, # in place changes
                                                                            K_nb_cluster,
                                                                            counts,
                                                                            indicator_vector,
                                                                            distances,
                                                                            cluster_names,
                                                                            cluster_names_sorted)

        #################
        # PALM4MSA step #
        #################

        # create the diagonal of the sqrt of those counts
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(counts[cluster_names_sorted] / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(counts[cluster_names_sorted])

        # set it as first factor
        op_factors.set_factor(0, diag_counts_sqrt_normalized)


        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
                hierarchical_palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)

        else:
            _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
                palm4msa(arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                         lst_S_init=op_factors.get_list_of_factors(),
                         nb_factors=op_factors.n_factors,
                         lst_projection_functions=lst_proj_op_by_fac_step[-1][
                             "finetune"],
                         f_lambda_init=_lambda * np.sqrt(nb_examples),
                         nb_iter=nb_iter_palm,
                         update_right_to_left=True,
                         track_objective=track_objective_palm,
                         delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        lst_all_objective_functions_palm.append(objective_palm)

        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        objective_function[i_iter] = compute_objective(X_data, op_centroids, indicator_vector)
        if i_iter >= 1:
            delta_objective_error = np.abs(objective_function[i_iter] - objective_function[i_iter-1]) / objective_function[i_iter-1]

        # todo: check that the absolute error is smaller than the threshold several times in a row

        i_iter += 1

    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] + lst_factors_[2:])

    return objective_function[:i_iter], op_centroids, indicator_vector, lst_all_objective_functions_palm
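
Note: a hypothetical invocation sketch, assuming X_data, init_centroids and lst_constraint_sets (e.g. from build_constraint_set_smart) are already available; the dict keys are exactly the ones read at the top of qmeans:

params_palm4msa = {
    "init_lambda": 1.0,
    "nb_iter": 300,
    "lst_constraint_sets": lst_constraint_sets,
    "residual_on_right": True,
    "delta_objective_error_threshold": 1e-6,
    "track_objective": False,
}

objectives, op_centroids, labels, palm_objectives = qmeans(
    X_data, K_nb_cluster=16, nb_iter=20, nb_factors=5,
    params_palm4msa=params_palm4msa, initialization=init_centroids)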
Example #6
def main_compare_prod_vec():
    daiquiri.setup(level=logging.INFO)
    min_power2 = 2
    max_power2 = 18
    # max_power2 = 6
    n_vec = 100

    dims = np.array([int(2 ** i) for i in range(min_power2, max_power2)])
    nb_replicates = 3
    times_pyqalm_vec = np.empty((len(dims), nb_replicates))
    times_faust_vec = np.empty((len(dims), nb_replicates))
    for i_dim, dim in enumerate(dims):
        for i_seed in range(nb_replicates):
            F_faust = wht(dim)
            F_pyqalm = SparseFactors([F_faust.factors(i) for i in range(F_faust.numfactors())])

            rand_matrix_of_examples = np.random.randn(dim, n_vec)

            start_pyqalm_vec = time.time()
            for i in range(n_vec):
                r = F_pyqalm @ rand_matrix_of_examples[:, i]
            stop_pyqalm_vec = time.time()
            time_pyqalm_vec = stop_pyqalm_vec - start_pyqalm_vec
            times_pyqalm_vec[i_dim, i_seed] = time_pyqalm_vec

            start_faust_vec = time.time()
            for i in range(n_vec):
                r = F_faust @ rand_matrix_of_examples[:, i]
            stop_faust_vec = time.time()
            time_faust_vec = stop_faust_vec - start_faust_vec
            times_faust_vec[i_dim, i_seed] = time_faust_vec

    np.savez("results_time_vec_bis", faust=times_faust_vec, pyqalm=times_pyqalm_vec)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=dims,
        y=np.mean(times_faust_vec, axis=1),
        name="faust",
        error_y=dict(
            type='data',  # value of error bar given in data coordinates
            array=np.std(times_faust_vec, axis=1),
            visible=True)
    ))

    fig.add_trace(go.Scatter(
        x=dims,
        y=np.mean(times_pyqalm_vec, axis=1),
        name="pyqalm",
        error_y=dict(
            type='data',  # value of error bar given in data coordinates
            array=np.std(times_pyqalm_vec, axis=1),
            visible=True)
    ))
    fig.update_layout(title="vec")

    fig.update_xaxes(title_text='Dimension')
    fig.update_yaxes(title_text='Time')
    fig.show()
    fig.write_image("vec_bis.png")
Example #7
def palm4msa_fast4(arr_X_target: np.array,
                   lst_S_init: list,
                   nb_factors: int,
                   lst_projection_functions: list,
                   f_lambda_init: float,
                   nb_iter: int,
                   update_right_to_left=True,
                   track_objective=False,
                   delta_objective_error_threshold=1e-6):
    """
    lst S init contains factors in decreasing indexes (e.g: the order along which they are multiplied in the product).
        example: S5 S4 S3 S2 S1

    lst S [-j] = Sj

    :param arr_X_target: The target to approximate.
    :param lst_S_init: The initial list of sparse factors.
    :param nb_factors: The number of factors.
    :param lst_projection_functions: The projection function for each of the sparse factor.
    :param f_lambda_init: The initial scaling factor.
    :param nb_iter: The number of iteration before stopping.
    :param update_right_to_left: Tells the algorithm to update factors from right to left (S1 first)
    :param graphical_display: Make a graphical representation of results.
    :param track_objective: If true, the objective function is computed for each factor and not only at the end of each iteration.
    :param delta_objective_error_threshold: The normalized difference threshold between error at two successive iterations threshold below which the computation is stopped.

    :return: the sparse factorization but careful: the final X isn't multiplyed by lambda
    """
    logger.debug('Norm of arr_X_target: {}'.format(
        np.linalg.norm(arr_X_target, ord='fro')))
    # initialization
    f_lambda = f_lambda_init
    S_factors_op = SparseFactors(lst_S_init)

    assert np.all(S_factors_op.shape == arr_X_target.shape)
    assert S_factors_op.n_factors > 0
    assert S_factors_op.n_factors == nb_factors

    if track_objective:
        objective_function = np.ones(
            (nb_iter,
             nb_factors + 1)) * -1  # (nb_factors + 1) because of the lambda
    else:
        objective_function = np.ones((nb_iter, 1)) * -1

    if update_right_to_left:
        # range arguments: start, stop, step
        factor_number_generator = range(-1, -(nb_factors + 1), -1)
    else:
        factor_number_generator = range(0, nb_factors, 1)
    # main loop
    i_iter = 0
    delta_objective_error = np.inf

    init_vectors_norm_comp_L = [None] * nb_factors
    init_vectors_norm_comp_R = [None] * nb_factors

    while ((i_iter < nb_iter)
           and (delta_objective_error > delta_objective_error_threshold)):

        for machine_idx_fac, j in enumerate(factor_number_generator):
            if lst_projection_functions[j].__name__ == "constant_proj":
                if track_objective:
                    objective_function[
                        i_iter, machine_idx_fac] = compute_objective_function(
                            arr_X_target,
                            _f_lambda=f_lambda,
                            _lst_S=S_factors_op)
                    logger.debug(
                        "Iteration {}; Factor idx {}; Objective value {}".
                        format(i_iter, j, objective_function[i_iter,
                                                             machine_idx_fac]))
                continue

            L = S_factors_op.get_L(j)
            R = S_factors_op.get_R(-j - 1)
            # R = S_factors_op.get_R(nb_factors - j - 1)
            # print(nb_factors, L.n_factors+R.n_factors+1, L.n_factors,
            #       R.n_factors, j, -j-1)

            # compute minimum c value (according to paper)
            L_norm, init_vectors_norm_comp_L[j] = L.compute_spectral_norm(init_vector_eigs_v0=init_vectors_norm_comp_L[j]) \
                if L.n_factors > 0 else (1, init_vectors_norm_comp_L[j])
            R_norm, init_vectors_norm_comp_R[j] = R.compute_spectral_norm(init_vector_eigs_v0=init_vectors_norm_comp_R[j]) \
                if R.n_factors > 0 else (1, init_vectors_norm_comp_R[j])
            min_c_value = (f_lambda * L_norm * R_norm)**2  # Lipschitz constant
            # take c slightly above min_c_value because the minimum is exclusive
            c = min_c_value * 1.001
            logger.debug("Lipschitz constant value: {}; c value: {}".format(
                min_c_value, c))
            # compute new factor value
            # todo check if it is not redundant to recompute the S_factors_op
            res = f_lambda * S_factors_op.compute_product() - arr_X_target
            # res_RH = R.dot(res.T).T if R.n_factors > 0 else res
            res_RH = S_factors_op.apply_RH(n_factors=-j - 1, X=res)
            # res_RH = S_factors_op.apply_RH(n_factors=nb_factors-j-1, X=res)
            LH_res_RH = S_factors_op.apply_LH(n_factors=j, X=res_RH)
            grad_step = 1. / c * f_lambda * LH_res_RH

            Sj = S_factors_op.get_factor(j)

            # normalize because all factors must have norm 1
            S_proj = lst_projection_functions[j](Sj - grad_step)
            S_proj = csr_matrix(S_proj)
            S_proj /= np.sqrt(S_proj.power(2).sum())

            S_factors_op.set_factor(j, S_proj)

            if track_objective:
                objective_function[
                    i_iter, machine_idx_fac] = compute_objective_function(
                        arr_X_target, _f_lambda=f_lambda, _lst_S=S_factors_op)
                logger.debug(
                    "Iteration {}; Factor idx {}; Objective value {}".format(
                        i_iter, j, objective_function[i_iter,
                                                      machine_idx_fac]))

        # re-compute the full factorisation
        # todo check if it is not redundant to recompute the S_factors_op
        arr_X_curr = S_factors_op.compute_product()

        # update lambda
        f_lambda = update_scaling_factor(X=arr_X_target, X_est=arr_X_curr)

        logger.debug("Lambda value: {}".format(f_lambda))

        objective_function[i_iter, -1] = \
            compute_objective_function(arr_X_target, _f_lambda=f_lambda,
                                       _lst_S=S_factors_op)

        logger.debug("Iteration {}; Objective value: {}".format(
            i_iter, objective_function[i_iter, -1]))

        if i_iter >= 1:
            delta_objective_error = np.abs(objective_function[i_iter, -1] -
                                           objective_function[i_iter - 1, -1]
                                           ) / objective_function[i_iter - 1,
                                                                  -1]
            logger.debug("Delta objective: {}".format(delta_objective_error))

        # TODO check that the absolute error is smaller than the threshold several times in a row

        i_iter += 1

    return f_lambda, S_factors_op, arr_X_curr, objective_function, i_iter
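
Note: the inner update above is, for each factor Sj with left product L and right product R, a projected gradient step on (1/2) * ||X - lambda * L @ Sj @ R||_F^2 with step 1/c, where c is taken slightly above the Lipschitz constant (lambda * ||L||_2 * ||R||_2)^2. A dense NumPy sketch of a single such step (proj stands for any projection, e.g. hard thresholding of the smallest entries):

import numpy as np

def palm_step_dense(X, L, Sj, R, f_lambda, proj):
    """One projected gradient step of PALM4MSA on the factor Sj (dense sketch)."""
    c = 1.001 * (f_lambda * np.linalg.norm(L, 2) * np.linalg.norm(R, 2)) ** 2
    res = f_lambda * (L @ Sj @ R) - X                  # residual of the current product
    grad = f_lambda * (L.conj().T @ res @ R.conj().T)  # gradient w.r.t. Sj
    Sj_new = proj(Sj - grad / c)
    return Sj_new / np.linalg.norm(Sj_new)             # all factors are kept at unit Frobenius norm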
Example #8
def hierarchical_palm4msa(arr_X_target: np.array,
                          lst_S_init: list,
                          lst_dct_projection_function: list,
                          nb_iter: int,
                          f_lambda_init: float = 1,
                          residual_on_right: bool = True,
                          update_right_to_left=True,
                          track_objective_palm=False,
                          return_objective_function=False,
                          delta_objective_error_threshold_palm=1e-6):
    """


    :param arr_X_target:
    :param lst_S_init: The factors are given right to left. In all case.
    :param nb_keep_values:
    :param f_lambda_init:
    :param nb_iter:
    :param update_right_to_left: Way in which the factors are updated in the inner palm4msa algorithm. If update_right_to_left is True,
    the factors are updated right to left (e.g; the last factor in the list first). Otherwise the contrary.
    :param residual_on_right: During the split step, the residual can be computed as a right or left factor. If residual_on_right is True,
    the residuals are computed as right factors. We can also see this option as the update way for the hierarchical strategy:
    when the residual is computed on the right, it correspond to compute the last factor first (left to right according to the paper: the factor with the
    bigger number first)
    :return:
    """
    if not update_right_to_left:
        raise NotImplementedError  # todo: see why it crashes... dimension mismatch

    arr_residual = arr_X_target

    op_S_factors = SparseFactors(deepcopy(lst_S_init))
    nb_factors = op_S_factors.n_factors

    # check that lst_dct_projection_function contains a list of dicts with params for the split and finetune steps
    assert len(
        lst_dct_projection_function
    ) == nb_factors - 1, "Number of factor {} and number of constraints {} are different".format(
        len(lst_dct_projection_function), nb_factors - 1)
    assert all(
        len({"split", "finetune"}.difference(dct.keys())) == 0
        for dct in lst_dct_projection_function)

    f_lambda = f_lambda_init

    if return_objective_function:
        objective_function = np.empty((nb_factors, 3))
    else:
        objective_function = None

    lst_objectives = []

    # main loop
    for k in range(nb_factors - 1):
        lst_objective_split_fine_fac_k = []

        nb_factors_so_far = k + 1

        logger.info("Working on factor: {}".format(k))
        logger.info("Step split")

        ########################## Step split ##########################################################

        if return_objective_function:
            # compute objective before split step
            objective_function[k, 0] = compute_objective_function(
                arr_X_target, f_lambda, op_S_factors)

        # compute the 2-factor decomposition of the previous residual
        if k == 0:
            f_lambda_init_split = f_lambda_init
        else:
            f_lambda_init_split = 1.

        func_split_step_palm4msa = lambda lst_S_init: palm4msa(
            arr_X_target=arr_residual,
            lst_S_init=lst_S_init,  # eye for factor and zeros for residual
            nb_factors=2,
            lst_projection_functions=lst_dct_projection_function[k]["split"],
            # define constraints: ||.||_0 = d for T1; relaxed constraint on ||.||_0 for T2
            f_lambda_init=f_lambda_init_split,
            nb_iter=nb_iter,
            update_right_to_left=update_right_to_left,
            track_objective=track_objective_palm,
            delta_objective_error_threshold=
            delta_objective_error_threshold_palm)

        if residual_on_right:
            op_S_factors_init = SparseFactors(lst_S_init[nb_factors_so_far:])
            # todo: this product could be computed beforehand to save computation
            residual_init = op_S_factors_init.compute_product()
            lst_S_init_split_step = [lst_S_init[k], residual_init]
            f_lambda_prime, S_out, unscaled_residual_reconstruction, objective_palm_split, _ = \
                func_split_step_palm4msa(lst_S_init=lst_S_init_split_step)
            new_factor = S_out.get_factor(0)
            new_residual = S_out.get_factor(1)
            op_S_factors.set_factor(k, new_factor)

        else:
            op_S_factors_init = SparseFactors(lst_S_init[:-nb_factors_so_far])
            # todo: this product could be computed beforehand to save computation
            residual_init = op_S_factors_init.compute_product()
            lst_S_init_split_step = [
                residual_init, lst_S_init[-nb_factors_so_far]
            ]
            f_lambda_prime, S_out, unscaled_residual_reconstruction, objective_palm_split, _ = \
                func_split_step_palm4msa(lst_S_init=lst_S_init_split_step)
            new_residual = S_out.get_factor(0)
            new_factor = S_out.get_factor(1)
            op_S_factors.set_factor(nb_factors - nb_factors_so_far, new_factor)

        if k == 0:
            f_lambda = f_lambda_prime
        else:
            f_lambda *= f_lambda_prime

        lst_objective_split_fine_fac_k.append(objective_palm_split)

        # get the k first elements [:k+1] and the next, (k+1)-th, one as arr_residual (depends on the residual_on_right option)
        logger.info("Step finetuning")

        ########################## Step finetuning ##########################################################

        if return_objective_function:
            objective_function[k, 1] = compute_objective_function(
                arr_X_target, f_lambda, op_S_factors)

        func_fine_tune_step_palm4msa = lambda lst_S_init: palm4msa(
            arr_X_target=arr_X_target,
            lst_S_init=lst_S_init,
            nb_factors=nb_factors_so_far + 1,
            lst_projection_functions=lst_dct_projection_function[k]["finetune"],
            f_lambda_init=f_lambda,
            nb_iter=nb_iter,
            update_right_to_left=update_right_to_left,
            track_objective=track_objective_palm,
            delta_objective_error_threshold=
            delta_objective_error_threshold_palm)

        if residual_on_right:
            lst_S_in = op_S_factors.get_list_of_factors()[:nb_factors_so_far]
            f_lambda, lst_S_out, _, objective_palm_fine, _ = \
                func_fine_tune_step_palm4msa(
                    lst_S_init=lst_S_in + [new_residual])
            for i in range(nb_factors_so_far):
                op_S_factors.set_factor(i, lst_S_out.get_factor(i))
            # TODO remove .toarray()?
            arr_residual = lst_S_out.get_factor(nb_factors_so_far).toarray()
        else:
            lst_S_in = op_S_factors.get_list_of_factors()[-nb_factors_so_far:]
            f_lambda, lst_S_out, _, objective_palm_fine, _ = \
                func_fine_tune_step_palm4msa(
                    lst_S_init=[new_residual] + lst_S_in)
            for i in range(nb_factors_so_far):
                op_S_factors.set_factor(-nb_factors_so_far + i,
                                        lst_S_out.get_factor(i + 1))
            # TODO remove .toarray()?
            arr_residual = lst_S_out.get_factor(0).toarray()

        lst_objective_split_fine_fac_k.append(objective_palm_fine)
        lst_objectives.append(tuple(lst_objective_split_fine_fac_k))

        if return_objective_function:
            objective_function[k, 2] = compute_objective_function(
                arr_X_target, f_lambda, op_S_factors)

    # the last factor is the residual of the last palm4msa call
    if residual_on_right:
        op_S_factors.set_factor(-1, arr_residual)
    else:
        op_S_factors.set_factor(0, arr_residual)

    if return_objective_function:
        objective_function[nb_factors - 1, :] = np.array(
            [compute_objective_function(arr_X_target, f_lambda, op_S_factors)
             ] * 3)

    arr_X_curr = f_lambda * op_S_factors.compute_product()

    return f_lambda, op_S_factors, arr_X_curr, lst_objectives, objective_function
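
Note: as the assertions at the top of hierarchical_palm4msa encode, lst_dct_projection_function must hold nb_factors - 1 entries, each a dict with a "split" and a "finetune" list of projection callables (2 for the split step, k + 2 for the finetune step at stage k). A hypothetical skeleton with placeholder projections (real ones come from the constraint-building code, e.g. build_constraint_set_smart):

def identity_proj(mat):
    return mat  # placeholder projection

nb_factors = 4
lst_dct_projection_function = [
    {"split": [identity_proj, identity_proj],
     "finetune": [identity_proj] * (k + 2)}
    for k in range(nb_factors - 1)
]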
Example #9
def qkmeans_minibatch(X_data: np.ndarray,
                      K_nb_cluster: int,
                      nb_iter: int,
                      nb_factors: int,
                      params_palm4msa: dict,
                      initialization: np.ndarray,
                      batch_size: int,
                      hierarchical_inside=False,
                      delta_objective_error_threshold=1e-6,
                      hierarchical_init=False):
    """
    :param X_data: The data matrix of n examples in dimensions d in shape (n, d).
    :param K_nb_cluster: The number of clusters to look for.
    :param nb_iter: The maximum number of iteration.
    :param nb_factors: The number of factors for the decomposition.
    :param initialization: The initial matrix of centroids not yet factorized.
    :param params_palm4msa: The dictionnary of parameters for the palm4msa algorithm.
    :param hierarchical_inside: Tell the algorithm if the hierarchical version of palm4msa should be used.
    :param delta_objective_error_threshold:
    :param hierarchical_init: Tells if the algorithm should make the initialization of sparse factors with the hierarchical version of palm or not.
    :param batch_size:  The size of each batch.
    
    :return:
    """

    assert K_nb_cluster == initialization.shape[0]

    logger.debug("Compute squared froebenius norm of data")
    X_data_norms = get_squared_froebenius_norm_line_wise_batch_by_batch(
        X_data, batch_size)

    nb_examples = X_data.shape[0]
    total_nb_of_minibatch = X_data.shape[0] // batch_size

    X_centroids_hat = copy.deepcopy(initialization)

    # ################################ INIT PALM4MSA ###############################
    logger.info("Initializing QKmeans with PALM algorithm")

    lst_factors = init_lst_factors(K_nb_cluster, X_centroids_hat.shape[1],
                                   nb_factors)
    eye_norm = np.sqrt(K_nb_cluster)

    ##########################
    # GET PARAMS OF PALM4MSA #
    ##########################
    init_lambda = params_palm4msa["init_lambda"]
    nb_iter_palm = params_palm4msa["nb_iter"]
    lst_proj_op_by_fac_step = params_palm4msa["lst_constraint_sets"]
    residual_on_right = params_palm4msa["residual_on_right"]
    delta_objective_error_threshold_inner_palm = params_palm4msa[
        "delta_objective_error_threshold"]
    track_objective_palm = params_palm4msa["track_objective"]

    ####################
    # INIT RUN OF PALM #
    ####################

    if hierarchical_inside or hierarchical_init:
        _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
            hierarchical_palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                lst_dct_projection_function=lst_proj_op_by_fac_step,
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                residual_on_right=residual_on_right,
                track_objective_palm=track_objective_palm,
                delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm,
                return_objective_function=track_objective_palm)
    else:
        _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
            palm4msa(
                arr_X_target=np.eye(K_nb_cluster) @ X_centroids_hat,
                lst_S_init=lst_factors,
                nb_factors=len(lst_factors),
                lst_projection_functions=lst_proj_op_by_fac_step[-1][
                    "finetune"],
                f_lambda_init=init_lambda * eye_norm,
                nb_iter=nb_iter_palm,
                update_right_to_left=True,
                track_objective=track_objective_palm,
                delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

    # ################################################################

    lst_factors = None  # safe assignment for debug

    _lambda = _lambda_tmp / eye_norm

    objective_function = np.ones(nb_iter) * -1
    lst_all_objective_functions_palm = [objective_palm]

    i_iter = 0
    delta_objective_error = np.inf
    while ((i_iter < nb_iter)
           and (delta_objective_error > delta_objective_error_threshold)):
        logger.info("Iteration number {}/{}".format(i_iter, nb_iter))

        # Re-init palm factors for iteration
        lst_factors_ = op_factors.get_list_of_factors()
        op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                     lst_factors_[2:])

        # Prepare next epoch
        full_count_vector = np.zeros(K_nb_cluster, dtype=int)
        full_indicator_vector = np.zeros(X_data.shape[0], dtype=int)

        X_centroids_hat = np.zeros_like(X_centroids_hat)

        for i_minibatch, example_batch_indexes in enumerate(
                DataGenerator(X_data,
                              batch_size=batch_size,
                              return_indexes=True)):
            logger.info(
                "Minibatch number {}/{}; Iteration number {}/{}".format(
                    i_minibatch, total_nb_of_minibatch, i_iter, nb_iter))
            example_batch = X_data[example_batch_indexes]
            example_batch_norms = X_data_norms[example_batch_indexes]

            ##########################
            # Update centroid oracle #
            ##########################

            indicator_vector, distances = assign_points_to_clusters(
                example_batch, op_centroids, X_norms=example_batch_norms)
            full_indicator_vector[example_batch_indexes] = indicator_vector

            cluster_names, counts = np.unique(indicator_vector,
                                              return_counts=True)
            count_vector = np.zeros(K_nb_cluster)
            count_vector[cluster_names] = counts

            full_count_vector = update_clusters(example_batch, X_centroids_hat,
                                                K_nb_cluster,
                                                full_count_vector,
                                                count_vector, indicator_vector)

        objective_function[i_iter] = compute_objective_by_batch(
            X_data, op_centroids, full_indicator_vector, batch_size)

        # in-place modification of X_centroids_hat, full_count_vector and full_indicator_vector
        check_cluster_integrity(X_data, X_centroids_hat, K_nb_cluster,
                                full_count_vector, full_indicator_vector)

        #########################
        # Do palm for iteration #
        #########################

        # create the diagonal of the sqrt of those counts
        diag_counts_sqrt_normalized = csr_matrix(
            (np.sqrt(full_count_vector / nb_examples),
             (np.arange(K_nb_cluster), np.arange(K_nb_cluster))))
        diag_counts_sqrt = np.sqrt(full_count_vector)

        # set it as first factor
        op_factors.set_factor(0, diag_counts_sqrt_normalized)

        if hierarchical_inside:
            _lambda_tmp, op_factors, _, objective_palm, array_objective_hierarchical = \
                hierarchical_palm4msa(
                    arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                    lst_S_init=op_factors.get_list_of_factors(),
                    lst_dct_projection_function=lst_proj_op_by_fac_step,
                    f_lambda_init=_lambda * np.sqrt(nb_examples),
                    nb_iter=nb_iter_palm,
                    update_right_to_left=True,
                    residual_on_right=residual_on_right,
                    return_objective_function=track_objective_palm,
                    track_objective_palm=track_objective_palm,
                    delta_objective_error_threshold_palm=delta_objective_error_threshold_inner_palm)

        else:
            _lambda_tmp, op_factors, _, objective_palm, nb_iter_palm = \
                palm4msa(arr_X_target=diag_counts_sqrt[:, None] * X_centroids_hat,
                         lst_S_init=op_factors.get_list_of_factors(),
                         nb_factors=op_factors.n_factors,
                         lst_projection_functions=lst_proj_op_by_fac_step[-1][
                             "finetune"],
                         f_lambda_init=_lambda * np.sqrt(nb_examples),
                         nb_iter=nb_iter_palm,
                         update_right_to_left=True,
                         track_objective=track_objective_palm,
                         delta_objective_error_threshold=delta_objective_error_threshold_inner_palm)

        _lambda = _lambda_tmp / np.sqrt(nb_examples)

        ############################

        lst_all_objective_functions_palm.append(objective_palm)

        if i_iter >= 1:
            delta_objective_error = np.abs(objective_function[i_iter] -
                                           objective_function[i_iter - 1]
                                           ) / objective_function[i_iter - 1]

        # todo: check that the absolute error is smaller than the threshold several times in a row

        i_iter += 1

    # refresh the factor list after the last PALM update before building the final operator
    lst_factors_ = op_factors.get_list_of_factors()
    op_centroids = SparseFactors([lst_factors_[1] * _lambda] +
                                 lst_factors_[2:])

    return objective_function[:i_iter], op_centroids, full_indicator_vector, lst_all_objective_functions_palm
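
Note: DataGenerator is project-specific; a minimal stand-in for the interface assumed above (with return_indexes=True it yields the row indexes of each minibatch) could look like:

import numpy as np

def minibatch_indexes(X, batch_size):
    """Minimal stand-in for DataGenerator(X, batch_size, return_indexes=True):
    yields the row indexes of each shuffled minibatch."""
    indexes = np.random.permutation(X.shape[0])
    for start in range(0, X.shape[0], batch_size):
        yield indexes[start:start + batch_size]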