Example #1
def pprint(msg, level=None, prefix='', linewidth=80):
    ''' Log msg via the module-level Log, wrapping long lines to linewidth.

    Accepts either a single string or a list of strings (with a matching
    list of prefixes). Also echoes to stdout when the default level is
    'print' and appends to RecentMessages when tracking is enabled.
    '''
    global Log
    global DEFAULTLEVEL
    if isinstance(msg, list):
        msgs = list()
        prefixes = list()
        for ii, m_ii in enumerate(msg):
            prefix_ii = prefix[ii]
            msgs_ii = split_str_into_fixed_width_lines(m_ii,
                                                       linewidth=linewidth -
                                                       len(prefix_ii))
            msgs.extend(msgs_ii)
            prefixes.extend([prefix[ii] for i in range(len(msgs_ii))])
        for ii in range(len(msgs)):
            pprint(prefixes[ii] + msgs[ii], level=level)
        return
    if DEFAULTLEVEL == 'print':
        print(msg)
    if Log is None:
        return
    if level is None:
        level = DEFAULTLEVEL
    if isinstance(level, str):
        if level.count('info'):
            level = logging.INFO
        elif level.count('debug'):
            level = logging.DEBUG
    Log.log(level, msg)
    for h in Log.handlers:
        h.flush()
    # Small bit of code to track recent messages
    # for debugging birth proposals. Used when dumping to HTML
    global RecentMessages
    if isinstance(RecentMessages, list):
        RecentMessages.append(msg)
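A minimal usage sketch, assuming this pprint lives in a module alongside the globals it references (Log, DEFAULTLEVEL, RecentMessages); the logger name 'birthlog' and the setup below are illustrative assumptions, not the library's own configuration code.

import logging

Log = logging.getLogger('birthlog')      # assumed name; any logger works
Log.setLevel(logging.DEBUG)
Log.addHandler(logging.StreamHandler())
DEFAULTLEVEL = logging.DEBUG
RecentMessages = list()                  # enable recent-message tracking

pprint('accepted birth proposal', level='info')      # mapped to logging.INFO
pprint('proposal diagnostics follow', level='debug')  # mapped to logging.DEBUG
print(RecentMessages)                    # both messages were recorded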
Example #2
def countvec2str(curN_K):
    # Format the per-cluster count vector curN_K as a fixed-width string,
    # then wrap it to the line width. W and split_str_into_fixed_width_lines
    # come from the enclosing scope (compare the inline version in Example #4).
    if W is None:
        str_sum_w = ' '.join(['%7.0f' % (x) for x in curN_K])
    else:
        assert np.allclose(curN_K.sum(), W.sum())
        str_sum_w = ' '.join(['%7.2f' % (x) for x in curN_K])
    return split_str_into_fixed_width_lines(str_sum_w, tostr=True)
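Every example on this page defers to split_str_into_fixed_width_lines, which is not shown here. A minimal sketch consistent with the call sites (a list of lines by default, a single newline-joined string when tostr=True) might look like the following; the real helper may instead wrap at whitespace boundaries, so treat this as an assumption.

def split_str_into_fixed_width_lines(mstr, linewidth=80, tostr=False):
    # Break mstr into chunks of at most linewidth characters.
    lines = [mstr[ii:ii + linewidth] for ii in range(0, len(mstr), linewidth)]
    if len(lines) == 0:
        lines = ['']
    if tostr:
        return '\n'.join(lines)
    return lines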
Example #3
def pprint(msg, level=None, prefix='', linewidth=80):
    global Log
    global DEFAULTLEVEL
    global RecentMessages
    if isinstance(msg, list):
        msgs = list()
        prefixes = list()
        for ii, m_ii in enumerate(msg):
            prefix_ii = prefix[ii]
            msgs_ii = split_str_into_fixed_width_lines(m_ii,
                                                       linewidth=linewidth -
                                                       len(prefix_ii))
            msgs.extend(msgs_ii)
            prefixes.extend([prefix[ii] for i in range(len(msgs_ii))])
        for ii in range(len(msgs)):
            pprint(prefixes[ii] + msgs[ii], level=level)
        return

    if level == 'print':
        print(msg)
    if Log is None:
        return
    if level is None:
        level = DEFAULTLEVEL
    if isinstance(level, str):
        if level.count('info'):
            level = logging.INFO
        elif level.count('debug'):
            level = logging.DEBUG
    Log.log(level, msg)
    if isinstance(RecentMessages, list):
        RecentMessages.append(msg)
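Continuing the hypothetical logger setup sketched under Example #1, the list branch of this variant pairs each message with its own prefix and wraps each one to the remaining line width; the messages and prefixes below are illustrative only.

pprint(['birth proposal accepted', 'birth proposal rejected'],
       prefix=['comp 007: ', 'comp 012: '],
       level='info')
# Each wrapped line is logged individually, e.g. 'comp 007: ...' then 'comp 012: ...'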
Example #4
def runKMeans_BregmanDiv(X,
                         K,
                         obsModel,
                         W=None,
                         Niter=100,
                         seed=0,
                         init='plusplus',
                         smoothFracInit=1.0,
                         smoothFrac=0,
                         logFunc=None,
                         eps=1e-10,
                         setOneToPriorMean=0,
                         distexp=1.0,
                         assert_monotonic=True,
                         **kwargs):
    ''' Run hard clustering algorithm to find K clusters.

    Returns
    -------
    Z : 1D array, size N
    Mu : 2D array, size K x D
    Lscores : 1D array, size Niter
    '''
    chosenZ, Mu, _, _ = initKMeans_BregmanDiv(
        X,
        K,
        obsModel,
        W=W,
        seed=seed,
        smoothFrac=smoothFracInit,
        distexp=distexp,
        setOneToPriorMean=setOneToPriorMean)
    # Make sure we update K to reflect the returned value.
    # initKMeans_BregmanDiv may return fewer than K clusters
    # in some edge cases, e.g. when the data matrix X has duplicate rows
    # and the specified K is larger than the number of unique rows.
    K = len(Mu)
    assert K > 0
    assert Niter >= 0
    if Niter == 0:
        Z = -1 * np.ones(X.shape[0])
        if chosenZ[0] == -1:
            Z[chosenZ[1:]] = np.arange(chosenZ.size - 1)
        else:
            Z[chosenZ] = np.arange(chosenZ.size)
    Lscores = list()
    prevN = np.zeros(K)
    for riter in range(Niter):
        Div = obsModel.calcSmoothedBregDiv(X=X,
                                           Mu=Mu,
                                           W=W,
                                           includeOnlyFastTerms=True,
                                           smoothFrac=smoothFrac,
                                           eps=eps)
        Z = np.argmin(Div, axis=1)
        Ldata = Div.min(axis=1).sum()
        Lprior = obsModel.calcBregDivFromPrior(Mu=Mu,
                                               smoothFrac=smoothFrac).sum()
        Lscore = Ldata + Lprior
        Lscores.append(Lscore)
        # Verify the objective is monotonically decreasing (non-increasing)
        if assert_monotonic:
            try:
                # Test allows small positive increases that are
                # numerically indistinguishable from zero. Don't care about these.
                assert np.all(np.diff(Lscores) <= 1e-5)
            except AssertionError:
                msg = "iter %d: Lscore %.3e" % (riter, Lscore)
                msg += '\nIn the kmeans update loop of FromScratchBregman.py'
                msg += '\nLscores not monotonically decreasing...'
                if logFunc:
                    logFunc(msg)
                else:
                    print(msg)
                assert np.all(np.diff(Lscores) <= 1e-5)

        N = np.zeros(K)
        for k in range(K):
            if W is None:
                W_k = None
                N[k] = np.sum(Z == k)
            else:
                W_k = W[Z == k]
                N[k] = np.sum(W_k)
            if N[k] > 0:
                Mu[k] = obsModel.calcSmoothedMu(X[Z == k], W_k)
            else:
                Mu[k] = obsModel.calcSmoothedMu(X=None)
        if logFunc:
            logFunc("iter %d: Lscore %.3e" % (riter, Lscore))
            if W is None:
                str_sum_w = ' '.join(['%7.0f' % (x) for x in N])
            else:
                assert np.allclose(N.sum(), W.sum())
                str_sum_w = ' '.join(['%7.2f' % (x) for x in N])
            str_sum_w = split_str_into_fixed_width_lines(str_sum_w, tostr=True)
            logFunc(str_sum_w)
        if np.max(np.abs(N - prevN)) == 0:
            break
        prevN[:] = N

    uniqueZ = np.unique(Z)
    if Niter > 0:
        # In case a cluster was pushed to zero
        if uniqueZ.size < len(Mu):
            Mu = [Mu[k] for k in uniqueZ]
    else:
        # Without a full pass through the dataset, many items are left
        # unassigned, which we indicate with a Z value of -1.
        # Ignore these when counting states.
        uniqueZ = uniqueZ[uniqueZ >= 0]
    assert len(Mu) == uniqueZ.size
    return Z, Mu, np.asarray(Lscores)
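runKMeans_BregmanDiv relies on the obsModel interface (calcSmoothedBregDiv, calcBregDivFromPrior, calcSmoothedMu) and on initKMeans_BregmanDiv, so it cannot run in isolation. As a self-contained illustration of the same hard-clustering loop, here is a toy analogue that uses squared Euclidean distance (itself a Bregman divergence) in place of the obsModel calls; every name in it is illustrative, not part of the library.

import numpy as np

def toy_kmeans_sqeuclidean(X, K, Niter=100, seed=0):
    # Same skeleton as above: assign each row to its nearest cluster under a
    # Bregman divergence, track the objective, update the cluster means, and
    # stop early once the per-cluster counts no longer change.
    PRNG = np.random.RandomState(seed)
    Mu = X[PRNG.choice(X.shape[0], K, replace=False)].copy()
    Lscores = list()
    prevN = np.zeros(K)
    for riter in range(Niter):
        # Squared Euclidean distance is the Bregman divergence of ||x||^2.
        Div = ((X[:, None, :] - Mu[None, :, :]) ** 2).sum(axis=2)
        Z = np.argmin(Div, axis=1)
        Lscores.append(Div.min(axis=1).sum())
        N = np.array([np.sum(Z == k) for k in range(K)])
        for k in range(K):
            if N[k] > 0:
                Mu[k] = X[Z == k].mean(axis=0)
        if np.max(np.abs(N - prevN)) == 0:
            break
        prevN = N
    return Z, Mu, np.asarray(Lscores)

X = np.random.RandomState(42).randn(200, 2)
Z, Mu, Lscores = toy_kmeans_sqeuclidean(X, K=3)
assert np.all(np.diff(Lscores) <= 1e-5)   # same monotonicity check as above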