def makeSummaryForBirthProposal(
        Dslice, curModel, curLPslice,
        curSSwhole=None,
        b_creationProposalName='bregmankmeans',
        targetUID=None,
        ktarget=None,
        newUIDs=None,
        LPkwargs=DefaultLPkwargs,
        lapFrac=0,
        batchID=0,
        seed=0,
        b_nRefineSteps=3,
        b_debugOutputDir=None,
        b_minNumAtomsForNewComp=None,
        b_doInitCompleteLP=1,
        b_cleanupWithMerge=1,
        b_method_initCoordAscent='fromprevious',
        vocabList=None,
        **kwargs):
    ''' Create summary that reassigns mass from target to Kfresh new comps.

    TODO support other options than bregman???

    Parameters
    ----------
    Dslice : data object for current batch/slice
    curModel : current HModel (allocation + observation model)
    curLPslice : dict of local parameters for Dslice under curModel
    curSSwhole : SuffStatBag summarizing the whole dataset
    targetUID : unique id of the cluster to split (or None to use ktarget)
    ktarget : position of target cluster (or None to use targetUID)
    newUIDs : list of uids reserved for the freshly created clusters
    b_nRefineSteps : int, number of restricted-local-step refinement passes
    b_debugOutputDir : str or None; if set, HTML/PNG diagnostics are written

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
    '''
    # Parse input to decide which cluster to target
    # * targetUID is the unique ID of this cluster
    # * ktarget is its position in the current cluster ordering
    if targetUID is None:
        targetUID = curSSwhole.k2uid(ktarget)
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Creating proposal for targetUID %s at lap %.2f batchID %d' % (
            targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint(
            'HTML output:' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(
            curModel, curSSwhole,
            os.path.join(b_debugOutputDir, 'OrigComps.png'),
            vocabList=vocabList,
            compsToHighlight=[ktarget])

    # Determine exactly how many new states we can make...
    xK = len(newUIDs)
    if xK + curSSwhole.K > kwargs['Kmax']:
        # Cap the expansion so total cluster count stays within --Kmax.
        xK = kwargs['Kmax'] - curSSwhole.K
        newUIDs = newUIDs[:xK]
        if xK <= 1:
            errorMsg = 'Cancelled.' + \
                'Adding 2 or more states would exceed budget of %d comps.' % (
                    kwargs['Kmax'])
            BLogger.pprint(errorMsg)
            BLogger.pprint('')
            return None, dict(errorMsg=errorMsg)
    # Create suff stats for some new states
    xInitSStarget, Info = initSS_BregmanDiv(
        Dslice, curModel, curLPslice,
        K=xK,
        ktarget=ktarget,
        lapFrac=lapFrac,
        seed=seed + int(1000 * lapFrac),
        logFunc=BLogger.pprint,
        NiterForBregmanKMeans=kwargs['b_NiterForBregmanKMeans'],
        **kwargs)
    # EXIT EARLY: if proposal initialization fails (not enough data).
    if xInitSStarget is None:
        BLogger.pprint('Proposal initialization FAILED. ' + \
            Info['errorMsg'])
        BLogger.pprint('')
        return None, Info

    # If here, we have a valid set of initial stats.
    xInitSStarget.setUIDs(newUIDs[:xInitSStarget.K])
    if b_doInitCompleteLP:
        # Create valid whole-dataset clustering from hard init
        xInitSSslice, tempInfo = makeExpansionSSFromZ(
            Dslice=Dslice, curModel=curModel, curLPslice=curLPslice,
            ktarget=ktarget,
            xInitSS=xInitSStarget,
            atomType=Info['atomType'],
            targetZ=Info['targetZ'],
            chosenDataIDs=Info['chosenDataIDs'],
            **kwargs)
        Info.update(tempInfo)
        xSSslice = xInitSSslice
    else:
        xSSslice = xInitSStarget

    if b_debugOutputDir:
        plotCompsFromSS(
            curModel, xSSslice,
            os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
            vocabList=vocabList)
        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        # Create initial proposal
        if b_doInitCompleteLP:
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(
                uid=targetUID, xSS=xSSslice)
            # Verify quality: total mass must be conserved by the transfer.
            assert np.allclose(propSS.getCountVec().sum(),
                               curSSwhole.getCountVec().sum())
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "init %d/%d gainL % .3e propL % .3e curL % .3e" % (
                    0, b_nRefineSteps,
                    propLdict['Ltotal'] - curLdict['Ltotal'],
                    propLdict['Ltotal'],
                    curLdict['Ltotal']))
            propLdictList.append(propLdict)
        # Track per-document usage of each fresh cluster (HDP models only).
        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xInitSStarget.uids):
                if 'targetZ' in Info:
                    if Info['atomType'].count('doc'):
                        initDocUsage_uid = np.sum(Info['targetZ'] == k)
                    else:
                        initDocUsage_uid = 0.0
                        # FIX: was xrange (Python-2-only builtin); every other
                        # loop in this file uses range.
                        for d in range(Dslice.nDoc):
                            start = Dslice.doc_range[d]
                            stop = Dslice.doc_range[d + 1]
                            initDocUsage_uid += np.any(
                                Info['targetZ'][start:stop] == k)
                else:
                    initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()
    if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
        xInitLPslice = Info['xLPslice']
    else:
        xInitLPslice = None
    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xSSslice)
    BLogger.pprint(" " + vec2str(xInitSStarget.uids))
    pprintCountVec(xSSslice)
    # Log messages to describe the initialization.
    BLogger.pprint(' Running %d refinement iterations (--b_nRefineSteps)' % (
        b_nRefineSteps))

    prevCountVec = xSSslice.getCountVec()
    didConvEarly = False
    # convstep marks the iteration at which counts converged; initialized
    # past the loop end so the early-break test below cannot fire spuriously.
    convstep = 100 + b_nRefineSteps
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    for rstep in range(b_nRefineSteps):
        # Update xObsModel
        xObsModel.update_global_params(xSSslice)
        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSslice.uids,
            xInitSS=xSSslice,
            xObsModel=xObsModel,
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            nUpdateSteps=1,
            **kwargs)
        Info.update(refineInfo)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        # On first step, show diagnostics for new states
        if rstep == 0:
            targetPi = refineInfo['emptyPi'] + refineInfo['xPiVec'].sum()
            BLogger.pprint(
                " target prob redistributed by policy %s (--b_method_xPi)" % (
                    kwargs['b_method_xPi']))
            msg = " pi[ktarget] before %.4f after %.4f." % (
                targetPi, refineInfo['emptyPi'])
            BLogger.pprint(msg)
            BLogger.pprint(" pi[new comps]: " + \
                vec2str(
                    refineInfo['xPiVec'],
                    width=6, minVal=0.0001))
            logLPConvergenceDiagnostics(
                refineInfo, rstep=rstep, b_nRefineSteps=b_nRefineSteps)
            BLogger.pprint(" " + vec2str(xInitSStarget.uids))
        # Show diagnostic counts in each fresh state
        pprintCountVec(xSSslice)
        # Write HTML debug info
        if b_debugOutputDir:
            plotCompsFromSS(
                curModel, xSSslice,
                os.path.join(b_debugOutputDir,
                             'NewComps_Step%d.png' % (rstep + 1)),
                vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(
                uid=targetUID, xSS=xSSslice)
            # Reordering only lifts score by small amount. Not worth it.
            # propSS.reorderComps(np.argsort(-1 * propSS.getCountVec()))
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            # Compare obs-model ELBO of the expansion against a single
            # merged cluster holding the same mass (gain if split helps).
            propSSsubset = xSSslice
            tmpModel = curModelFWD
            tmpModel.obsModel.update_global_params(propSSsubset)
            propLdata_subset = tmpModel.obsModel.calcELBO_Memoized(
                propSSsubset)
            curSSsubset = xSSslice.copy(includeELBOTerms=0)
            while curSSsubset.K > 1:
                curSSsubset.mergeComps(0, 1)
            tmpModel.obsModel.update_global_params(curSSsubset)
            curLdata_subset = tmpModel.obsModel.calcELBO_Memoized(curSSsubset)
            gainLdata_subset = propLdata_subset - curLdata_subset
            msg = \
                "step %d/%d gainL % .3e propL % .3e curL % .3e" % (
                    rstep+1, b_nRefineSteps,
                    propLdict['Ltotal'] - curLdict['Ltotal'],
                    propLdict['Ltotal'],
                    curLdict['Ltotal'])
            msg += " gainLdata_subset % .3e" % (gainLdata_subset)
            BLogger.pprint(msg)
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])
        # If converged early and did the final refinement step
        if didConvEarly and rstep > convstep:
            break
        # Cleanup by deleting small clusters
        if rstep < b_nRefineSteps - 1:
            if rstep == b_nRefineSteps - 2 or didConvEarly:
                # After all but last step,
                # delete small (but not empty) comps
                minNumAtomsToStay = b_minNumAtomsForNewComp
            else:
                # Always remove empty clusters. They waste our time.
                minNumAtomsToStay = np.minimum(1, b_minNumAtomsForNewComp)
            xSSslice, xInitLPslice = cleanupDeleteSmallClusters(
                xSSslice, minNumAtomsToStay,
                xInitLPslice=xInitLPslice,
                pprintCountVec=pprintCountVec)
        # Decide if we have converged early
        if rstep < b_nRefineSteps - 2 and prevCountVec.size == xSSslice.K:
            if np.allclose(xSSslice.getCountVec(), prevCountVec, atol=0.5):
                # Converged. Jump directly to the merge phase!
                didConvEarly = True
                convstep = rstep
        # Cleanup by merging clusters
        if b_cleanupWithMerge and \
                (rstep == b_nRefineSteps - 2 or didConvEarly):
            # Only cleanup on second-to-last pass, or if converged early
            Info['mergestep'] = rstep + 1
            xSSslice, xInitLPslice = cleanupMergeClusters(
                xSSslice, curModel,
                obsSSkeys=xInitSStarget._Fields._FieldDims.keys(),
                vocabList=vocabList,
                pprintCountVec=pprintCountVec,
                xInitLPslice=xInitLPslice,
                b_debugOutputDir=b_debugOutputDir,
                **kwargs)
        prevCountVec = xSSslice.getCountVec().copy()

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(
            b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict, propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(
                b_debugOutputDir, 'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID,
                                    savefilename=savefilename)

    # EXIT EARLY: error if we didn't create enough "big-enough" states.
    nnzCount = np.sum(xSSslice.getCountVec() >= b_minNumAtomsForNewComp)
    if nnzCount < 2:
        Info['errorMsg'] = \
            "Could not create at least two comps" + \
            " with mass >= %.1f (--%s)" % (
                b_minNumAtomsForNewComp, 'b_minNumAtomsForNewComp')
        BLogger.pprint('Proposal build phase FAILED. ' + Info['errorMsg'])
        BLogger.pprint('')  # Blank line
        return None, Info

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal build phase DONE.' + \
        ' Created %d candidate clusters.' % (Info['Kfinal']))
    BLogger.pprint('')  # Blank line
    return xSSslice, Info
def cleanupMergeClusters(
        xSSslice, curModel,
        xInitLPslice=None,
        obsSSkeys=None,
        vocabList=None,
        b_cleanupMaxNumMergeIters=3,
        b_cleanupMaxNumAcceptPerIter=1,
        b_mergeLam=None,
        b_debugOutputDir=None,
        pprintCountVec=None,
        **kwargs):
    ''' Merge all possible pairs of clusters that improve the Ldata objective.

    Parameters
    ----------
    xSSslice : SuffStatBag of candidate clusters (modified in place)
    curModel : current HModel; only its obsModel is used for merge scoring
    xInitLPslice : optional dict with 'DocTopicCount'; columns are merged
        in lockstep with the cluster merges
    obsSSkeys : iterable of field names to KEEP in xSSslice
        (all other fields are discarded)
    b_mergeLam : optional override for the topic-word prior hyperparameter

    Returns
    -------
    xSSslice : SuffStatBag
        May have fewer components than K.
    xInitLPslice : dict or None, with merged columns removed.
    '''
    xSSslice.removeELBOandMergeTerms()
    xSSslice.removeSelectionTerms()
    # Discard all fields unrelated to observation model
    reqFields = set(obsSSkeys)
    # FIX: materialize key list before removing fields; on Python 3,
    # .keys() is a live view and removeField (presumably) mutates the
    # underlying dict, which would raise RuntimeError mid-iteration.
    for key in list(xSSslice._Fields._FieldDims.keys()):
        if key not in reqFields:
            xSSslice.removeField(key)

    # For merges, we can crank up value of the topic-word prior hyperparameter,
    # to prioritize only care big differences in word counts across many terms
    tmpModel = curModel.copy()
    if b_mergeLam is not None:
        tmpModel.obsModel.Prior.lam[:] = b_mergeLam

    mergeID = 0
    for trial in range(b_cleanupMaxNumMergeIters):
        tmpModel.obsModel.update_global_params(xSSslice)
        GainLdata = tmpModel.obsModel.calcHardMergeGap_AllPairs(xSSslice)
        triuIDs = np.triu_indices(xSSslice.K, 1)
        posLocs = np.flatnonzero(GainLdata[triuIDs] > 0)
        if posLocs.size == 0:
            # No merges to accept. Stop!
            break
        # Rank the positive pairs from largest to smallest
        sortIDs = np.argsort(-1 * GainLdata[triuIDs][posLocs])
        posLocs = posLocs[sortIDs]
        # Greedily accept best non-overlapping pairs, at most
        # b_cleanupMaxNumAcceptPerIter per trial.
        usedUIDs = set()
        nAccept = 0
        uidpairsToAccept = list()
        origidsToAccept = list()
        for loc in posLocs:
            kA = triuIDs[0][loc]
            kB = triuIDs[1][loc]
            uidA = xSSslice.uids[triuIDs[0][loc]]
            uidB = xSSslice.uids[triuIDs[1][loc]]
            if uidA in usedUIDs or uidB in usedUIDs:
                continue
            usedUIDs.add(uidA)
            usedUIDs.add(uidB)
            uidpairsToAccept.append((uidA, uidB))
            origidsToAccept.append((kA, kB))
            nAccept += 1
            if nAccept >= b_cleanupMaxNumAcceptPerIter:
                break
        for posID, (uidA, uidB) in enumerate(uidpairsToAccept):
            mergeID += 1
            kA, kB = origidsToAccept[posID]
            xSSslice.mergeComps(uidA=uidA, uidB=uidB)
            if xInitLPslice:
                # Fold column kB into kA; mark kB with -1 so it can be
                # deleted after this batch of merges (indices stay valid
                # because columns are only removed below).
                xInitLPslice['DocTopicCount'][:, kA] += \
                    xInitLPslice['DocTopicCount'][:, kB]
                xInitLPslice['DocTopicCount'][:, kB] = -1
            if b_debugOutputDir:
                savefilename = os.path.join(
                    b_debugOutputDir, 'MergeComps_%d.png' % (mergeID))
                # Show side-by-side topics
                bnpy.viz.PlotComps.plotCompsFromHModel(
                    tmpModel,
                    compListToPlot=[kA, kB],
                    vocabList=vocabList,
                    xlabels=[str(uidA), str(uidB)],
                )
                bnpy.viz.PlotUtil.pylab.savefig(
                    savefilename, pad_inches=0)
        if len(uidpairsToAccept) > 0:
            pprintCountVec(xSSslice, uidpairsToAccept=uidpairsToAccept)
        if xInitLPslice:
            # Drop the columns flagged -1 above, right-to-left so earlier
            # indices remain valid while deleting.
            badIDs = np.flatnonzero(xInitLPslice['DocTopicCount'][0, :] < 0)
            for kk in reversed(badIDs):
                xInitLPslice['DocTopicCount'] = np.delete(
                    xInitLPslice['DocTopicCount'], kk, axis=1)

    if mergeID > 0 and b_debugOutputDir:
        tmpModel.obsModel.update_global_params(xSSslice)
        outpath = os.path.join(b_debugOutputDir, 'NewComps_AfterMerge.png')
        plotCompsFromSS(
            tmpModel, xSSslice, outpath,
            vocabList=vocabList,
        )
    if xInitLPslice:
        assert xInitLPslice['DocTopicCount'].min() > -0.000001
        assert xInitLPslice['DocTopicCount'].shape[1] == xSSslice.K
    return xSSslice, xInitLPslice
def makeSummaryForExistingBirthProposal(
        Dslice, curModel, curLPslice,
        curSSwhole=None,
        targetUID=None,
        ktarget=None,
        LPkwargs=DefaultLPkwargs,
        lapFrac=0,
        batchID=0,
        b_nRefineSteps=3,
        b_debugOutputDir=None,
        b_method_initCoordAscent='fromprevious',
        vocabList=None,
        **kwargs):
    ''' Create summary that reassigns mass from target given set of comps

    Given set of comps is a fixed proposal from a previously-seen batch,
    stored in curSSwhole.propXSS[targetUID].

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
    '''
    if targetUID is None:
        # FIX: was curSSwhole.uids(ktarget). Elsewhere in this file uids is
        # an indexable sequence (enumerate(xinitSS.uids), xSSslice.uids[...]),
        # so calling it raises TypeError; use the k2uid lookup exactly as
        # makeSummaryForBirthProposal does.
        targetUID = curSSwhole.k2uid(ktarget)
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Extending previous birth for targetUID %s at lap %.2f batch %d' % (
            targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint(
            'HTML output:' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(
            curModel, curSSwhole,
            os.path.join(b_debugOutputDir, 'OrigComps.png'),
            vocabList=vocabList,
            compsToHighlight=[ktarget])

    assert targetUID in curSSwhole.propXSS
    xinitSS = curSSwhole.propXSS[targetUID]
    xK = xinitSS.K
    if xK + curSSwhole.K > kwargs['Kmax']:
        errorMsg = 'Cancelled.' + \
            'Adding 2 or more states would exceed budget of %d comps.' % (
                kwargs['Kmax'])
        BLogger.pprint(errorMsg)
        BLogger.pprint('')
        return None, dict(errorMsg=errorMsg)

    # Log messages to describe the initialization.
    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xinitSS)
    BLogger.pprint(' Using previous proposal with %d clusters %s.' % (
        xinitSS.K, '(--b_Kfresh=%d)' % kwargs['b_Kfresh']))
    BLogger.pprint(" Initial uid/counts from previous proposal:")
    BLogger.pprint(' ' + vec2str(xinitSS.uids))
    pprintCountVec(xinitSS)
    BLogger.pprint(' Running %d refinement iterations (--b_nRefineSteps)' % (
        b_nRefineSteps))

    xSSinitPlusSlice = xinitSS.copy()
    if b_debugOutputDir:
        plotCompsFromSS(
            curModel, xinitSS,
            os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
            vocabList=vocabList)
        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xinitSS.uids):
                initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()
    xInitLPslice = None
    Info = dict()
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    nRefineSteps = np.maximum(1, b_nRefineSteps)
    for rstep in range(nRefineSteps):
        xObsModel.update_global_params(xSSinitPlusSlice)
        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSinitPlusSlice.uids,
            xObsModel=xObsModel,
            xInitSS=xSSinitPlusSlice,  # first time in loop <= xinitSS
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            **kwargs)
        # Running total = prior-batch stats + latest slice stats.
        # Subtract the previous slice so the current one replaces it.
        xSSinitPlusSlice += xSSslice
        if rstep >= 1:
            xSSinitPlusSlice -= prevSSslice
        prevSSslice = xSSslice

        Info.update(refineInfo)
        # Show diagnostics for new states
        pprintCountVec(xSSslice)
        logLPConvergenceDiagnostics(
            refineInfo, rstep=rstep, b_nRefineSteps=b_nRefineSteps)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        if b_debugOutputDir:
            plotCompsFromSS(
                curModel, xSSslice,
                os.path.join(b_debugOutputDir,
                             'NewComps_Step%d.png' % (rstep + 1)),
                vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(
                uid=targetUID, xSS=xSSslice)
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "step %d/%d gainL % .3e propL % .3e curL % .3e" % (
                    rstep + 1, b_nRefineSteps,
                    propLdict['Ltotal'] - curLdict['Ltotal'],
                    propLdict['Ltotal'],
                    curLdict['Ltotal']))
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(
            b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict, propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(
                b_debugOutputDir, 'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID,
                                    savefilename=savefilename)

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal extension DONE. %d candidate clusters.' % (
        Info['Kfinal']))
    BLogger.pprint('')
    return xSSslice, Info