def solve(C, p, D, k, *, solve_dense, etas, is_kl_not_js, J, L):
    """Search several warm starts for the lowest-divergence coefficient vector.

    For every ``eta`` in ``etas`` the iterates produced by ``warm.iterate``
    at the positions listed in ``J`` each seed one candidate: the support is
    set to ``sorting.argmaxs(y, k)`` and ``solve_dense`` is re-run on those
    columns of ``C``.  Every candidate is then refined once per entry ``l``
    of ``L``: the columns chosen by ``sorting.argmins`` over the
    ``identification.shift`` scores are added to the support, a dense solve
    is done, the support is truncated back with ``sorting.argmaxs(..., k)``
    and the result is scored again.  The vector with the smallest
    ``D(p, C @ y)`` seen anywhere is returned.

    Args:
        C: 2-D array; its columns are selected with boolean masks.
        p: Target passed to ``solve_dense`` and as first argument of ``D``.
        D: Callable ``D(p, q)`` returning a value comparable with ``<``.
        k: Count forwarded to ``sorting.argmaxs`` when truncating a support.
        solve_dense: Callable ``solve_dense(C[:, S], p)`` returning the
            coefficients for the selected columns.
        etas: Iterable of values forwarded as ``eta`` to ``warm.iterate``.
        is_kl_not_js: Flag forwarded unchanged to ``warm.iterate``.
        J: Sized collection of 0-based iterate positions to evaluate.
        L: Iterable of counts forwarded to ``sorting.argmins``.

    Returns:
        An array of length ``C.shape[1]`` that is zero outside the winning
        support.

    Raises:
        ValueError: If no candidate produced a divergence below infinity,
            for example because ``etas`` or ``J`` is empty.
    """
    n = C.shape[1]

    best_y = None
    best_divergence = math.inf
    solve_dense_cache = {}

    def cached_solve_dense(S):
        # Key on the selected column indices so that a support reached more
        # than once is only solved once.
        indices = frozenset(numpy.nonzero(S)[0])
        try:
            return solve_dense_cache[indices]
        except KeyError:
            y = solve_dense(C[:, S], p)
            solve_dense_cache[indices] = y
            return y

    def dense_fit(S):
        # Full-length vector holding the dense solution on the support S.
        y = numpy.zeros(n)
        y[S] = cached_solve_dense(S)
        return y

    def evaluate_top_k(S, y):
        # Shrink S (in place) to argmaxs(y, k), refit on it, record the
        # result if it is the best so far, and return the fitted q.
        nonlocal best_y, best_divergence
        S.fill(False)
        S[sorting.argmaxs(y, k)] = True
        y = dense_fit(S)
        q = C[:, S] @ y[S]
        divergence = D(p, q)
        if divergence < best_divergence:
            best_y = y
            best_divergence = divergence
        return q

    # A set gives O(1) membership tests and ignores duplicates in J; nothing
    # past the largest wanted position ever needs to be generated.
    wanted = frozenset(J)
    stop = max(int(max(wanted)) + 1, 0) if wanted else 0

    for eta in etas:
        ys = warm.iterate(C=C,
                          p=p,
                          D=D,
                          eta=eta,
                          is_kl_not_js=is_kl_not_js,
                          q=None)

        for i, y in enumerate(itertools.islice(ys, stop)):
            if i not in wanted:
                continue

            S = numpy.full(n, False)
            q = evaluate_top_k(S, y)

            for l in L:
                spaces = identification.shift(C=C, p=p, D=D, q=q)

                # Grow the current support, solve on the enlarged support,
                # then truncate back and score the truncated refit.
                S[sorting.argmins(spaces, l)] = True
                q = evaluate_top_k(S, dense_fit(S))

    if best_y is None:
        raise ValueError(
            "no candidate with a divergence below infinity was found; "
            "check that etas and J are non-empty")

    return best_y
Esempio n. 2
0
def solve(C, p, D, k, *, solve_dense, eta, is_kl_not_js, j, L):
    """Refine the ``warm_kl_like.solve`` result with shift-based support swaps.

    The warm-start vector supplies the initial support and fitted ``q``.
    For each ``l`` in ``L`` the columns picked by ``sorting.argmins`` over
    the ``identification.shift`` scores are added to the support, a dense
    solve is done on it, and the support is truncated back to
    ``sorting.argmaxs(..., k)``.  After the loop the final support is
    refitted once and returned only if its divergence ``D(p, C @ y)`` is
    strictly smaller than the warm start's; otherwise the warm start is
    returned.

    Args:
        C: 2-D array; its columns are selected with boolean masks.
        p: Target passed to ``solve_dense`` and as first argument of ``D``.
        D: Callable ``D(p, q)`` returning a value comparable with ``<``.
        k: Count forwarded to ``sorting.argmaxs`` and the warm start.
        solve_dense: Callable ``solve_dense(C[:, S], p)`` returning the
            coefficients for the selected columns.
        eta, is_kl_not_js, j: Forwarded unchanged to ``warm_kl_like.solve``.
        L: Iterable of counts forwarded to ``sorting.argmins``.

    Returns:
        An array of length ``C.shape[1]``: either the refined vector or the
        warm-start vector.
    """
    num_columns = C.shape[1]

    def refit(mask):
        # Full-length vector carrying the dense solution on the masked columns.
        coefficients = numpy.zeros(num_columns)
        coefficients[mask] = solve_dense(C[:, mask], p)
        return coefficients

    warm_y = warm_kl_like.solve(
        C, p, D, k,
        solve_dense=solve_dense,
        eta=eta,
        is_kl_not_js=is_kl_not_js,
        j=j,
    )
    support = warm_y != 0
    q = C[:, support] @ warm_y[support]
    warm_divergence = D(p, q)

    for l in L:
        scores = identification.shift(C=C, p=p, D=D, q=q)
        support[sorting.argmins(scores, l)] = True

        expanded = refit(support)

        # Truncate back; q uses the expanded fit restricted to the new
        # support (no refit inside the loop).
        support[...] = False
        support[sorting.argmaxs(expanded, k)] = True
        q = C[:, support] @ expanded[support]

    final_y = refit(support)
    final_divergence = D(p, C[:, support] @ final_y[support])
    if final_divergence < warm_divergence:
        return final_y
    return warm_y
def solve(C, p, D, k, *, solve_dense, L):
    """Return the best refit seen while growing and truncating a support.

    Starting from an empty support and ``q = 0``, each ``l`` in ``L`` adds
    the columns picked by ``sorting.argmins`` over the
    ``identification.shift`` scores, solves densely on the enlarged
    support, truncates the support to ``sorting.argmaxs(..., k)`` and
    refits.  The refit with the smallest ``D(p, C @ y)`` across all rounds
    is returned.

    Args:
        C: 2-D array; its columns are selected with boolean masks.
        p: Target passed to ``solve_dense`` and as first argument of ``D``.
        D: Callable ``D(p, q)`` returning a value comparable with ``<``.
        k: Count forwarded to ``sorting.argmaxs`` when truncating.
        solve_dense: Callable ``solve_dense(C[:, S], p)`` returning the
            coefficients for the selected columns.
        L: Iterable of counts forwarded to ``sorting.argmins``.

    Returns:
        An array of length ``C.shape[1]`` that is zero outside the winning
        support.

    Raises:
        ValueError: If no round produced a divergence below infinity, for
            example because ``L`` is empty.
    """
    m, n = C.shape
    S = numpy.full(n, False)
    q = numpy.zeros(m)
    best_y = None
    best_divergence = math.inf

    def dense_fit(mask):
        # Full-length vector carrying the dense solution on the masked columns.
        y = numpy.zeros(n)
        y[mask] = solve_dense(C[:, mask], p)
        return y

    for l in L:
        spaces = identification.shift(C=C, p=p, D=D, q=q)
        S[sorting.argmins(spaces, l)] = True

        y = dense_fit(S)

        # Keep only the entries selected by argmaxs, then refit on them.
        S.fill(False)
        S[sorting.argmaxs(y, k)] = True

        y = dense_fit(S)

        q = C[:, S] @ y[S]
        divergence = D(p, q)

        if divergence < best_divergence:
            best_y = y
            best_divergence = divergence

    if best_y is None:
        raise ValueError(
            "no candidate with a divergence below infinity was found; "
            "check that L is non-empty")

    return best_y
Esempio n. 4
0
def solve(C, p, D, k, *, solve_dense, eta, is_kl_not_js, j, L):
    """Refine the ``warm_kl_like.solve`` result using further warm iterations.

    The warm-start vector supplies the initial support and fitted ``q``.
    For each ``l`` in ``L`` a fresh ``warm.iterate`` sequence is started
    from the current ``q``; the nonzeros of the first iterate that has at
    least ``l`` nonzeros (or of the iterate at index ``j``, whichever comes
    first) are added to the support.  A dense solve on the enlarged support
    is then truncated back to ``sorting.argmaxs(..., k)``.  After the loop
    the final support is refitted once and returned only if its divergence
    ``D(p, C @ y)`` is strictly smaller than the warm start's; otherwise
    the warm start is returned.

    Args:
        C: 2-D array; its columns are selected with boolean masks.
        p: Target passed to ``solve_dense`` and as first argument of ``D``.
        D: Callable ``D(p, q)`` returning a value comparable with ``<``.
        k: Count forwarded to ``sorting.argmaxs`` and the warm start.
        solve_dense: Callable ``solve_dense(C[:, S], p)`` returning the
            coefficients for the selected columns.
        eta: Forwarded to ``warm_kl_like.solve`` and ``warm.iterate``.
        is_kl_not_js: Forwarded to ``warm.iterate`` only (see note below).
        j: Forwarded to ``warm_kl_like.solve``; also the iterate index at
            which the search for enough nonzeros gives up.
        L: Iterable of minimum nonzero counts, one per refinement round.

    Returns:
        An array of length ``C.shape[1]``: either the refined vector or the
        warm-start vector.
    """
    num_columns = C.shape[1]

    def refit(mask):
        # Full-length vector carrying the dense solution on the masked columns.
        coefficients = numpy.zeros(num_columns)
        coefficients[mask] = solve_dense(C[:, mask], p)
        return coefficients

    # NOTE(review): the warm start is always requested with
    # is_kl_not_js=True, whereas the iterations below use the caller's
    # is_kl_not_js -- confirm this asymmetry is intended.
    warm_y = warm_kl_like.solve(
        C, p, D, k,
        solve_dense=solve_dense,
        eta=eta,
        is_kl_not_js=True,
        j=j,
    )
    support = warm_y != 0
    q = C[:, support] @ warm_y[support]
    warm_divergence = D(p, q)

    for l in L:
        iterates = warm.iterate(
            C=C, p=p, D=D, eta=eta, is_kl_not_js=is_kl_not_js, q=q)
        step = next(
            candidate
            for index, candidate in enumerate(iterates)
            if numpy.count_nonzero(candidate) >= l or index >= j
        )
        support |= step != 0

        expanded = refit(support)

        # Truncate back; q uses the expanded fit restricted to the new
        # support (no refit inside the loop).
        support[...] = False
        support[sorting.argmaxs(expanded, k)] = True
        q = C[:, support] @ expanded[support]

    final_y = refit(support)
    final_divergence = D(p, C[:, support] @ final_y[support])
    if final_divergence < warm_divergence:
        return final_y
    return warm_y
def solve(C, p, D, k, *, solve_dense, L):
    """Grow and truncate a support once per entry of ``L``, then refit.

    Starting from an empty support and ``q = 0``, each ``l`` in ``L`` adds
    the columns picked by ``sorting.argmins`` over the
    ``identification.shift`` scores, solves densely on the enlarged
    support, and truncates the support to ``sorting.argmaxs(..., k)``.
    The dense solution on the final support is returned; intermediate
    results are not compared.

    Args:
        C: 2-D array; its columns are selected with boolean masks.
        p: Target passed to ``solve_dense`` and ``identification.shift``.
        D: Forwarded to ``identification.shift``.
        k: Count forwarded to ``sorting.argmaxs`` when truncating.
        solve_dense: Callable ``solve_dense(C[:, S], p)`` returning the
            coefficients for the selected columns.
        L: Iterable of counts forwarded to ``sorting.argmins``.

    Returns:
        An array of length ``C.shape[1]`` that is zero outside the final
        support.
    """
    num_rows, num_columns = C.shape

    def refit(mask):
        # Full-length vector carrying the dense solution on the masked columns.
        coefficients = numpy.zeros(num_columns)
        coefficients[mask] = solve_dense(C[:, mask], p)
        return coefficients

    support = numpy.zeros(num_columns, dtype=bool)
    q = numpy.zeros(num_rows)

    for l in L:
        scores = identification.shift(C=C, p=p, D=D, q=q)
        support[sorting.argmins(scores, l)] = True

        expanded = refit(support)

        # Truncate back; q uses the expanded fit restricted to the new
        # support (the refit happens only once, after the loop).
        support[...] = False
        support[sorting.argmaxs(expanded, k)] = True
        q = C[:, support] @ expanded[support]

    return refit(support)
def test_argmaxs(values, count, expected_indices):
    """Check that ``sorting.argmaxs`` returns exactly the expected indices.

    The returned indices are sorted before comparison, so their order does
    not matter; ``expected_indices`` must therefore be given in ascending
    order.
    """
    actual_indices = sorting.argmaxs(values, count)

    # array_equal also compares shapes.  An elementwise ``==`` would
    # broadcast, so an empty result compared against a one-element
    # expectation would pass vacuously, and mismatched lengths would fail
    # with an unrelated error instead of a clean assertion failure.
    assert numpy.array_equal(numpy.sort(actual_indices), expected_indices)