def pprint(msg, level=None, prefix='', linewidth=80):
    """Pretty-print ``msg`` to stdout and/or the module-level logger ``Log``.

    If ``msg`` is a list, ``prefix`` must be indexable with one prefix per
    entry; each entry is wrapped to ``linewidth`` (minus its prefix width)
    and emitted via a recursive call.

    NOTE(review): a second ``pprint`` definition appears later in this file
    and shadows this one at import time — confirm which is intended.
    """
    global Log
    global DEFAULTLEVEL
    if isinstance(msg, list):
        wrapped_lines = []
        wrapped_prefixes = []
        for idx, item in enumerate(msg):
            cur_prefix = prefix[idx]
            pieces = split_str_into_fixed_width_lines(
                item, linewidth=linewidth - len(cur_prefix))
            wrapped_lines += pieces
            wrapped_prefixes += [prefix[idx]] * len(pieces)
        for pre, line in zip(wrapped_prefixes, wrapped_lines):
            pprint(pre + line, level=level)
        return
    # NOTE(review): this echoes to stdout based on the module default, not
    # the per-call ``level`` — the later duplicate checks ``level`` instead;
    # confirm which behavior is intended.
    if DEFAULTLEVEL == 'print':
        print(msg)
    if Log is None:
        return
    if level is None:
        level = DEFAULTLEVEL
    if isinstance(level, str):
        if level.count('info'):
            level = logging.INFO
        elif level.count('debug'):
            level = logging.DEBUG
    Log.log(level, msg)
    for handler in Log.handlers:
        handler.flush()
    # Small bit of code to track recent messages
    # for debugging birth proposals. Used when dumping to HTML
    global RecentMessages
    if isinstance(RecentMessages, list):
        RecentMessages.append(msg)
def countvec2str(curN_K):
    """Format the 1D count vector ``curN_K`` as a fixed-width string.

    Uses integer-style formatting for unweighted counts and two decimal
    places when the module-level weight vector ``W`` is set (in which case
    the counts must sum to the total weight).
    """
    if W is None:
        fmt = '%7.0f'
    else:
        # Weighted counts must preserve total mass.
        assert np.allclose(curN_K.sum(), W.sum())
        fmt = '%7.2f'
    joined = ' '.join(fmt % (val,) for val in curN_K)
    return split_str_into_fixed_width_lines(joined, tostr=True)
def pprint(msg, level=None, prefix='', linewidth=80):
    """Pretty-print ``msg`` to stdout and/or the module-level logger ``Log``.

    If ``msg`` is a list, ``prefix`` must be indexable with one prefix per
    entry; each entry is wrapped to ``linewidth`` (minus its prefix width)
    and emitted via a recursive call.

    NOTE(review): this re-definition shadows the earlier ``pprint`` in this
    file; unlike that one it checks the per-call ``level`` for stdout echo
    and does not flush handlers — confirm the duplication is intentional.
    """
    global Log
    global DEFAULTLEVEL
    global RecentMessages
    if isinstance(msg, list):
        wrapped_lines = []
        wrapped_prefixes = []
        for idx, item in enumerate(msg):
            cur_prefix = prefix[idx]
            pieces = split_str_into_fixed_width_lines(
                item, linewidth=linewidth - len(cur_prefix))
            wrapped_lines += pieces
            wrapped_prefixes += [prefix[idx]] * len(pieces)
        for pre, line in zip(wrapped_prefixes, wrapped_lines):
            pprint(pre + line, level=level)
        return
    if level == 'print':
        print(msg)
    if Log is None:
        return
    if level is None:
        level = DEFAULTLEVEL
    if isinstance(level, str):
        if level.count('info'):
            level = logging.INFO
        elif level.count('debug'):
            level = logging.DEBUG
    Log.log(level, msg)
    # Track recent messages for debugging (e.g. when dumping to HTML).
    if isinstance(RecentMessages, list):
        RecentMessages.append(msg)
def runKMeans_BregmanDiv(X, K, obsModel, W=None, Niter=100, seed=0, init='plusplus', smoothFracInit=1.0, smoothFrac=0, logFunc=None, eps=1e-10, setOneToPriorMean=0, distexp=1.0, assert_monotonic=True, **kwargs):
    ''' Run hard clustering algorithm to find K clusters.

    Alternates hard assignment (argmin of a smoothed Bregman divergence)
    and cluster-mean updates, tracking an objective Lscore per iteration.

    Parameters
    ----------
    X : 2D array, size N x D
        Data matrix (row per item).
    K : int
        Requested number of clusters; may be reduced by initialization.
    obsModel : observation model object
        Must provide calcSmoothedBregDiv, calcBregDivFromPrior,
        and calcSmoothedMu (semantics defined elsewhere in the project).
    W : 1D array, size N, optional
        Per-item weights; None means unweighted.
    Niter : int
        Number of hard-assignment iterations; 0 skips the update loop.
    logFunc : callable, optional
        Receives progress strings; falls back to print for warnings.
    assert_monotonic : bool
        If True, fail when the objective trace increases beyond tolerance.

    Returns
    -------
    Z : 1D array, size N
    Mu : 2D array, size K x D
    Lscores : 1D array, size Niter
    '''
    chosenZ, Mu, _, _ = initKMeans_BregmanDiv(
        X, K, obsModel, W=W, seed=seed, smoothFrac=smoothFracInit, distexp=distexp, setOneToPriorMean=setOneToPriorMean)
    # Make sure we update K to reflect the returned value.
    # initKMeans_BregmanDiv will return fewer than K clusters
    # in some edge cases, like when data matrix X has duplicate rows
    # and specified K is larger than the number of unique rows.
    K = len(Mu)
    assert K > 0
    assert Niter >= 0
    if Niter == 0:
        # No refinement pass: mark everything unassigned (-1) except the
        # items chosen during initialization, which get cluster ids 0..K-1.
        Z = -1 * np.ones(X.shape[0])
        if chosenZ[0] == -1:
            # Leading -1 sentinel in chosenZ: only entries after it are
            # real chosen-item indices. TODO confirm against
            # initKMeans_BregmanDiv's return convention.
            Z[chosenZ[1:]] = np.arange(chosenZ.size - 1)
        else:
            Z[chosenZ] = np.arange(chosenZ.size)
    Lscores = list()
    prevN = np.zeros(K)  # counts from previous iteration, for convergence test
    for riter in range(Niter):
        # Hard-assign each item to the cluster with smallest divergence.
        Div = obsModel.calcSmoothedBregDiv(X=X, Mu=Mu, W=W, includeOnlyFastTerms=True, smoothFrac=smoothFrac, eps=eps)
        Z = np.argmin(Div, axis=1)
        Ldata = Div.min(axis=1).sum()
        Lprior = obsModel.calcBregDivFromPrior(Mu=Mu, smoothFrac=smoothFrac).sum()
        Lscore = Ldata + Lprior
        Lscores.append(Lscore)
        # Verify objective is monotonically decreasing (non-increasing
        # within numerical tolerance). NOTE: the original comment here said
        # "increasing", but the assertion and the warning message below both
        # enforce a decreasing trace.
        if assert_monotonic:
            try:
                # Test allows small positive increases that are
                # numerically indistinguishable from zero. Don't care about these.
                assert np.all(np.diff(Lscores) <= 1e-5)
            except AssertionError:
                # Log diagnostics first, then fail with the same condition.
                msg = "iter %d: Lscore %.3e" % (riter, Lscore)
                msg += '\nIn the kmeans update loop of FromScratchBregman.py'
                msg += '\nLscores not monotonically decreasing...'
                if logFunc:
                    logFunc(msg)
                else:
                    print(msg)
                assert np.all(np.diff(Lscores) <= 1e-5)
        # Recompute each cluster's (weighted) size and smoothed mean.
        N = np.zeros(K)
        for k in range(K):
            if W is None:
                W_k = None
                N[k] = np.sum(Z == k)
            else:
                W_k = W[Z == k]
                N[k] = np.sum(W_k)
            if N[k] > 0:
                Mu[k] = obsModel.calcSmoothedMu(X[Z == k], W_k)
            else:
                # Empty cluster: fall back to the prior-only smoothed mean.
                Mu[k] = obsModel.calcSmoothedMu(X=None)
        if logFunc:
            logFunc("iter %d: Lscore %.3e" % (riter, Lscore))
            # Format the count vector: integers when unweighted, two
            # decimals when weighted (counts must then sum to total weight).
            if W is None:
                str_sum_w = ' '.join(['%7.0f' % (x) for x in N])
            else:
                assert np.allclose(N.sum(), W.sum())
                str_sum_w = ' '.join(['%7.2f' % (x) for x in N])
            str_sum_w = split_str_into_fixed_width_lines(str_sum_w, tostr=True)
            logFunc(str_sum_w)
        # Converged when no item changed cluster (counts are unchanged).
        if np.max(np.abs(N - prevN)) == 0:
            break
        prevN[:] = N
    uniqueZ = np.unique(Z)
    if Niter > 0:
        # In case a cluster was pushed to zero
        if uniqueZ.size < len(Mu):
            Mu = [Mu[k] for k in uniqueZ]
    else:
        # Without full pass through dataset, many items not assigned
        # which we indicated with Z value of -1
        # Should ignore this when counting states
        uniqueZ = uniqueZ[uniqueZ >= 0]
    assert len(Mu) == uniqueZ.size
    return Z, Mu, np.asarray(Lscores)