def preselect_all_merge_candidates(
    curModel,
    SS,
    randstate=np.random,
    preselectroutine="random",
    mergePerLap=10,
    compIDs=None,
    **kwargs
):
    """
    Create and return a list of candidate component pairs to try to merge.

    Args
    --------
    curModel : bnpy HModel
        curModel.allocModel.K (number of components) is read here; the model
        is also passed to the pair selector / scoring routines.
    SS : bnpy SuffStatBag, or None.
        If None (first lap), selection falls back to 'random'.
    randstate : numpy random number generator
    preselectroutine : str
        name of the procedure used to select candidate pairs:
        'allpairsfromlist' : every pair of comps within compIDs,
            in sorted order.
        'allpairsfromlistbipartite' : every pair (a, b) with a in compIDs
            and b not in compIDs.
        'bestnmatchfromlist' : up to 3 partners per comp in compIDs chosen
            by the selector with mergename='marglik', then more
            selector-chosen pairs until mergePerLap is reached.
        'freshbestmatch' : one selector-chosen partner per comp in compIDs,
            then more selector-chosen pairs until mergePerLap is reached.
        'random' : pairs chosen by the selector with mergename='random'.
        'marglik' : pairs with the largest non-zero merge scores among
            all pairs.
        'marglikfromlistbipartite' : for each comp in compIDs, up to 3
            partners outside compIDs with the largest non-zero scores.
        Any other name yields an empty list.
    mergePerLap : int
        maximum number of candidates to return (may be fewer if K is small)
    compIDs : iterable of int, optional
        component IDs used by the list-based routines. Defaults to empty.
    **kwargs : extra keyword arguments are accepted and ignored.

    Returns
    --------
    mPairList : list of (kA, kB) tuples of component IDs
    """
    # Number of partners gathered per listed comp by the "best N" routines.
    nPerComp = 3
    nMergeTrials = mergePerLap
    K = curModel.allocModel.K
    if SS is None:
        # Handle first lap: no suff stats yet, so pick pairs at random.
        preselectroutine = "random"

    # Materialize once: avoids a shared mutable default, and lets one-shot
    # iterables be read twice (set difference here, sorting below).
    compIDs = [] if compIDs is None else list(compIDs)
    # Components NOT in compIDs, in ascending order (deterministic).
    partnerIDs = sorted(set(range(K)).difference(compIDs))

    def _round_robin(groups):
        """Flatten groups rank-by-rank: each group's 1st item, then 2nd..."""
        maxLen = max([len(g) for g in groups]) if groups else 0
        return [g[r] for r in range(maxLen) for g in groups if r < len(g)]

    def _extend_with_selector(pairs, trial, MTracker, MSelector, mergename):
        """Append selector-chosen pairs until nMergeTrials trials are made
        or MTracker reports no available pairs."""
        while MTracker.hasAvailablePairs() and trial < nMergeTrials:
            kA, kB = MSelector.select_merge_components(
                curModel, SS, MTracker, mergename=mergename,
                randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1

    pairs = list()
    if preselectroutine == "allpairsfromlist":
        compIDs = sorted(compIDs)
        for aa, kA in enumerate(compIDs):
            for kB in compIDs[aa + 1:]:
                pairs.append((kA, kB))
        pairs = pairs[:nMergeTrials]

    elif preselectroutine == "allpairsfromlistbipartite":
        for kA in sorted(compIDs):
            for kB in partnerIDs:
                pairs.append((min(kA, kB), max(kA, kB)))
        pairs = pairs[:nMergeTrials]

    elif preselectroutine == "bestnmatchfromlist":
        # Loop thru and find up to nPerComp best pairs for each comp in list
        compIDs = sorted(compIDs)
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        groups = list()
        for kFixed in compIDs:
            if not MTracker.hasAvailablePairs() or trial >= nMergeTrials:
                break
            group = list()
            while (MTracker.hasAvailablePartnersForComp(kFixed)
                   and len(group) < nPerComp):
                kA, kB = MSelector.select_merge_components(
                    curModel, SS, MTracker, mergename="marglik",
                    kA=kFixed, randstate=randstate)
                MTracker.recordResult(kA=kA, kB=kB)
                group.append((kA, kB))
                trial += 1
            groups.append(group)
        # Order rank-by-rank so that, after truncation, we're likely to try
        # every comp in compIDs at least once (robust to short groups).
        pairs = _round_robin(groups)[:nMergeTrials]
        # Continue to add selector-chosen pairs until we reach nMergeTrials.
        _extend_with_selector(pairs, trial, MTracker, MSelector, "marglik")

    elif preselectroutine == "freshbestmatch":
        compIDs = sorted(compIDs)
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        nFresh = min(len(compIDs), nMergeTrials)
        # First give each listed comp (sorted order) one partner.
        while MTracker.hasAvailablePairs() and trial < nFresh:
            kA, kB = MSelector.select_merge_components(
                curModel, SS, MTracker, mergename="marglik",
                kA=compIDs[trial], randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1
        # Continue to add to list until we've maxed out nMergeTrials.
        _extend_with_selector(pairs, trial, MTracker, MSelector, "marglik")

    elif preselectroutine == "random":
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        _extend_with_selector(pairs, 0, MTracker, MSelector, "random")

    elif preselectroutine == "marglik":
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        # Only the strict upper triangle (kA < kB) is scored.
        for kA in range(K):
            for kB in range(kA + 1, K):
                M[kA, kB] = MSelector._calcMScoreForCandidatePair(
                    curModel, SS, kA, kB)
        # find the n largest non-zero entries
        flatM = M.flatten()
        bestIDs = np.argsort(flatM)[::-1]
        bestIDs = bestIDs[flatM[bestIDs] != 0]
        bestrs, bestcs = np.unravel_index(bestIDs, M.shape)
        assert np.all(bestrs < bestcs)
        pairs = list(zip(bestrs[:nMergeTrials].tolist(),
                         bestcs[:nMergeTrials].tolist()))

    elif preselectroutine == "marglikfromlistbipartite":
        # Consider best candidates for each comp in list, only partnering
        # with components outside the list.
        # NOTE(review): MTracker is discarded at return, so recording here
        # has no lasting effect -- presumably kept for its internal
        # validity checks; confirm or remove.
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        groups = list()
        for kA in sorted(compIDs):
            scores = np.zeros(K)
            for kB in partnerIDs:
                scores[kB] = MSelector._calcMScoreForCandidatePair(
                    curModel, SS, kA, kB)
            # keep the nPerComp largest non-zero scores
            bestIDs = np.argsort(scores)[::-1]
            bestIDs = bestIDs[scores[bestIDs] != 0][:nPerComp]
            group = list()
            for kB in bestIDs:
                lo, hi = min(kA, int(kB)), max(kA, int(kB))
                MTracker.recordResult(kA=lo, kB=hi)
                group.append((lo, hi))
            groups.append(group)
        # Order rank-by-rank so we're likely to try all compIDs once.
        pairs = _round_robin(groups)[:nMergeTrials]

    assert len(pairs) <= nMergeTrials
    return pairs
def preselect_all_merge_candidates(
    curModel,
    SS,
    randstate=np.random,
    preselectroutine="random",
    mergePerLap=10,
    compIDs=None,
    **kwargs
):
    """
    Create and return a list of candidate component pairs to try to merge.

    Args
    --------
    curModel : bnpy HModel
        curModel.allocModel.K (number of components) is read here; the model
        is also passed to the pair selector / scoring routines.
    SS : bnpy SuffStatBag, or None.
        If None (first lap), selection falls back to 'random'.
    randstate : numpy random number generator
    preselectroutine : str
        name of the procedure used to select candidate pairs:
        'allpairsfromlist' : every pair of comps within compIDs,
            in sorted order.
        'allpairsfromlistbipartite' : every pair (a, b) with a in compIDs
            and b not in compIDs.
        'bestnmatchfromlist' : up to 3 partners per comp in compIDs chosen
            by the selector with mergename='marglik', then more
            selector-chosen pairs until mergePerLap is reached.
        'freshbestmatch' : one selector-chosen partner per comp in compIDs,
            then more selector-chosen pairs until mergePerLap is reached.
        'random' : pairs chosen by the selector with mergename='random'.
        'marglik' : pairs with the largest non-zero merge scores among
            all pairs.
        'marglikfromlistbipartite' : for each comp in compIDs, up to 3
            partners outside compIDs with the largest non-zero scores.
        Any other name yields an empty list.
    mergePerLap : int
        maximum number of candidates to return (may be fewer if K is small)
    compIDs : iterable of int, optional
        component IDs used by the list-based routines. Defaults to empty.
    **kwargs : extra keyword arguments are accepted and ignored.

    Returns
    --------
    mPairList : list of (kA, kB) tuples of component IDs
    """
    # Number of partners gathered per listed comp by the "best N" routines.
    nPerComp = 3
    nMergeTrials = mergePerLap
    K = curModel.allocModel.K
    if SS is None:
        # Handle first lap: no suff stats yet, so pick pairs at random.
        preselectroutine = "random"

    # Materialize once: avoids a shared mutable default, and lets one-shot
    # iterables be read twice (set difference here, sorting below).
    compIDs = [] if compIDs is None else list(compIDs)
    # Components NOT in compIDs, in ascending order (deterministic).
    partnerIDs = sorted(set(range(K)).difference(compIDs))

    def _round_robin(groups):
        """Flatten groups rank-by-rank: each group's 1st item, then 2nd..."""
        maxLen = max([len(g) for g in groups]) if groups else 0
        return [g[r] for r in range(maxLen) for g in groups if r < len(g)]

    def _extend_with_selector(pairs, trial, MTracker, MSelector, mergename):
        """Append selector-chosen pairs until nMergeTrials trials are made
        or MTracker reports no available pairs."""
        while MTracker.hasAvailablePairs() and trial < nMergeTrials:
            kA, kB = MSelector.select_merge_components(
                curModel, SS, MTracker, mergename=mergename,
                randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1

    pairs = list()
    if preselectroutine == "allpairsfromlist":
        compIDs = sorted(compIDs)
        for aa, kA in enumerate(compIDs):
            for kB in compIDs[aa + 1:]:
                pairs.append((kA, kB))
        pairs = pairs[:nMergeTrials]

    elif preselectroutine == "allpairsfromlistbipartite":
        for kA in sorted(compIDs):
            for kB in partnerIDs:
                pairs.append((min(kA, kB), max(kA, kB)))
        pairs = pairs[:nMergeTrials]

    elif preselectroutine == "bestnmatchfromlist":
        # Loop thru and find up to nPerComp best pairs for each comp in list
        compIDs = sorted(compIDs)
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        groups = list()
        for kFixed in compIDs:
            if not MTracker.hasAvailablePairs() or trial >= nMergeTrials:
                break
            group = list()
            while (MTracker.hasAvailablePartnersForComp(kFixed)
                   and len(group) < nPerComp):
                kA, kB = MSelector.select_merge_components(
                    curModel, SS, MTracker, mergename="marglik",
                    kA=kFixed, randstate=randstate)
                MTracker.recordResult(kA=kA, kB=kB)
                group.append((kA, kB))
                trial += 1
            groups.append(group)
        # Order rank-by-rank so that, after truncation, we're likely to try
        # every comp in compIDs at least once (robust to short groups).
        pairs = _round_robin(groups)[:nMergeTrials]
        # Continue to add selector-chosen pairs until we reach nMergeTrials.
        _extend_with_selector(pairs, trial, MTracker, MSelector, "marglik")

    elif preselectroutine == "freshbestmatch":
        compIDs = sorted(compIDs)
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        nFresh = min(len(compIDs), nMergeTrials)
        # First give each listed comp (sorted order) one partner.
        while MTracker.hasAvailablePairs() and trial < nFresh:
            kA, kB = MSelector.select_merge_components(
                curModel, SS, MTracker, mergename="marglik",
                kA=compIDs[trial], randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1
        # Continue to add to list until we've maxed out nMergeTrials.
        _extend_with_selector(pairs, trial, MTracker, MSelector, "marglik")

    elif preselectroutine == "random":
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        _extend_with_selector(pairs, 0, MTracker, MSelector, "random")

    elif preselectroutine == "marglik":
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        # Only the strict upper triangle (kA < kB) is scored.
        for kA in range(K):
            for kB in range(kA + 1, K):
                M[kA, kB] = MSelector._calcMScoreForCandidatePair(
                    curModel, SS, kA, kB)
        # find the n largest non-zero entries
        flatM = M.flatten()
        bestIDs = np.argsort(flatM)[::-1]
        bestIDs = bestIDs[flatM[bestIDs] != 0]
        bestrs, bestcs = np.unravel_index(bestIDs, M.shape)
        assert np.all(bestrs < bestcs)
        pairs = list(zip(bestrs[:nMergeTrials].tolist(),
                         bestcs[:nMergeTrials].tolist()))

    elif preselectroutine == "marglikfromlistbipartite":
        # Consider best candidates for each comp in list, only partnering
        # with components outside the list.
        # NOTE(review): MTracker is discarded at return, so recording here
        # has no lasting effect -- presumably kept for its internal
        # validity checks; confirm or remove.
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        groups = list()
        for kA in sorted(compIDs):
            scores = np.zeros(K)
            for kB in partnerIDs:
                scores[kB] = MSelector._calcMScoreForCandidatePair(
                    curModel, SS, kA, kB)
            # keep the nPerComp largest non-zero scores
            bestIDs = np.argsort(scores)[::-1]
            bestIDs = bestIDs[scores[bestIDs] != 0][:nPerComp]
            group = list()
            for kB in bestIDs:
                lo, hi = min(kA, int(kB)), max(kA, int(kB))
                MTracker.recordResult(kA=lo, kB=hi)
                group.append((lo, hi))
            groups.append(group)
        # Order rank-by-rank so we're likely to try all compIDs once.
        pairs = _round_robin(groups)[:nMergeTrials]

    assert len(pairs) <= nMergeTrials
    return pairs
def run_many_merge_moves(
    hmodel,
    Data,
    SS,
    evBound=None,
    nMergeTrials=1,
    compList=None,
    randstate=np.random,
    mPairIDs=None,
    **mergeKwArgs
):
    """
    Run (potentially many) merge moves on hmodel, one trial at a time.

    Each trial's candidate comes from, in order of precedence: the next pair
    in ``mPairIDs`` (if given); otherwise the next component popped from the
    end of ``compList`` (partner left to the selector); otherwise a pair left
    entirely to the selector.

    Args
    -------
    hmodel : bnpy HModel
    Data : data object, passed through to run_merge_move
    SS : bnpy SuffStatBag
    evBound : float or None
        objective value for (hmodel, SS). If None, it is computed with
        hmodel.calc_evidence(SS=SS).
    nMergeTrials : int
        number of merges to try; raised to len(compList) if that is larger.
    compList : list of int, optional
        components to include in attempted merges. The caller's list is
        not modified.
    randstate : numpy random number generator
    mPairIDs : list of (kA, kB) tuples, optional
        explicit candidate pairs, tried in order. When given, the loop stops
        once it is exhausted. The caller's list is not modified.
    **mergeKwArgs : forwarded to run_merge_move

    Returns
    -------
    hmodel : model returned by the last run_merge_move call
        (the input hmodel if no trial ran)
    SS : suff stats returned by the last run_merge_move call
    evBound : objective value after the last trial
    MTracker : MergeTracker holding the result of every attempted merge
    """
    # Work on private copies: the loop pops from these lists. Previously the
    # caller's mPairIDs was consumed only until the first accepted merge
    # rebound it -- an inconsistent side effect.
    compList = [] if compList is None else list(compList)
    if mPairIDs is not None:
        mPairIDs = list(mPairIDs)

    nMergeTrials = max(nMergeTrials, len(compList))
    MTracker = MergeTracker(SS.K)
    MSelector = MergePairSelector()

    # Exclude all pairs for which we did not compute the combined entropy Hz
    # Hz is always stored in KxK matrix. Pairs that were skipped have zeros.
    if SS.hasMergeTerm("ElogqZ"):
        Hz = SS.getMergeTerm("ElogqZ")
        aList = list()
        bList = list()
        for kA in range(SS.K):
            for kB in range(kA + 1, SS.K):
                if Hz[kA, kB] == 0:
                    aList.append(kA)
                    bList.append(kB)
        if aList:
            MTracker.addPairsToExclude(aList, bList)

    if evBound is None:
        newEv = hmodel.calc_evidence(SS=SS)
    else:
        newEv = evBound

    trialID = 0
    while trialID < nMergeTrials and MTracker.hasAvailablePairs():
        oldEv = newEv
        if mPairIDs is not None:
            if not mPairIDs:
                break
            kA, kB = mPairIDs.pop(0)
            try:
                MTracker.verifyPair(kA, kB)
            except AssertionError:
                # Pairs rejected by verifyPair are skipped without counting
                # as a trial.
                print(" AssertionError skipped with mPairIDs! %s %s"
                      % (kA, kB))
                continue
        elif compList:
            kA = compList.pop()
            # NOTE(review): unlike mPairIDs, remaining compList IDs are not
            # reindexed after an accepted merge -- confirm this is intended.
            if kA not in MTracker.getAvailableComps():
                continue
            kB = None
        else:
            kA = None
            kB = None

        hmodel, SS, newEv, MoveInfo = run_merge_move(
            hmodel, Data, SS, oldEv, kA=kA, kB=kB, randstate=randstate,
            MSelector=MSelector, MTracker=MTracker, **mergeKwArgs)
        if MoveInfo["didAccept"]:
            # An accepted merge must never decrease the objective.
            assert newEv >= oldEv
            if mPairIDs is not None:
                # Component IDs change after a merge; keep pending pairs valid.
                mPairIDs = _reindexCandidatePairsAfterAcceptedMerge(
                    mPairIDs, kA, kB)
        trialID += 1
        MTracker.recordResult(**MoveInfo)
    return hmodel, SS, newEv, MTracker
def run_many_merge_moves(
    hmodel,
    Data,
    SS,
    evBound=None,
    nMergeTrials=1,
    compList=None,
    randstate=np.random,
    mPairIDs=None,
    **mergeKwArgs
):
    """
    Run (potentially many) merge moves on hmodel, one trial at a time.

    Each trial's candidate comes from, in order of precedence: the next pair
    in ``mPairIDs`` (if given); otherwise the next component popped from the
    end of ``compList`` (partner left to the selector); otherwise a pair left
    entirely to the selector.

    Args
    -------
    hmodel : bnpy HModel
    Data : data object, passed through to run_merge_move
    SS : bnpy SuffStatBag
    evBound : float or None
        objective value for (hmodel, SS). If None, it is computed with
        hmodel.calc_evidence(SS=SS).
    nMergeTrials : int
        number of merges to try; raised to len(compList) if that is larger.
    compList : list of int, optional
        components to include in attempted merges. The caller's list is
        not modified.
    randstate : numpy random number generator
    mPairIDs : list of (kA, kB) tuples, optional
        explicit candidate pairs, tried in order. When given, the loop stops
        once it is exhausted. The caller's list is not modified.
    **mergeKwArgs : forwarded to run_merge_move

    Returns
    -------
    hmodel : model returned by the last run_merge_move call
        (the input hmodel if no trial ran)
    SS : suff stats returned by the last run_merge_move call
    evBound : objective value after the last trial
    MTracker : MergeTracker holding the result of every attempted merge
    """
    # Work on private copies: the loop pops from these lists. Previously the
    # caller's mPairIDs was consumed only until the first accepted merge
    # rebound it -- an inconsistent side effect.
    compList = [] if compList is None else list(compList)
    if mPairIDs is not None:
        mPairIDs = list(mPairIDs)

    nMergeTrials = max(nMergeTrials, len(compList))
    MTracker = MergeTracker(SS.K)
    MSelector = MergePairSelector()

    # Exclude all pairs for which we did not compute the combined entropy Hz
    # Hz is always stored in KxK matrix. Pairs that were skipped have zeros.
    if SS.hasMergeTerm("ElogqZ"):
        Hz = SS.getMergeTerm("ElogqZ")
        aList = list()
        bList = list()
        for kA in range(SS.K):
            for kB in range(kA + 1, SS.K):
                if Hz[kA, kB] == 0:
                    aList.append(kA)
                    bList.append(kB)
        if aList:
            MTracker.addPairsToExclude(aList, bList)

    if evBound is None:
        newEv = hmodel.calc_evidence(SS=SS)
    else:
        newEv = evBound

    trialID = 0
    while trialID < nMergeTrials and MTracker.hasAvailablePairs():
        oldEv = newEv
        if mPairIDs is not None:
            if not mPairIDs:
                break
            kA, kB = mPairIDs.pop(0)
            try:
                MTracker.verifyPair(kA, kB)
            except AssertionError:
                # Pairs rejected by verifyPair are skipped without counting
                # as a trial.
                print(" AssertionError skipped with mPairIDs! %s %s"
                      % (kA, kB))
                continue
        elif compList:
            kA = compList.pop()
            # NOTE(review): unlike mPairIDs, remaining compList IDs are not
            # reindexed after an accepted merge -- confirm this is intended.
            if kA not in MTracker.getAvailableComps():
                continue
            kB = None
        else:
            kA = None
            kB = None

        hmodel, SS, newEv, MoveInfo = run_merge_move(
            hmodel, Data, SS, oldEv, kA=kA, kB=kB, randstate=randstate,
            MSelector=MSelector, MTracker=MTracker, **mergeKwArgs)
        if MoveInfo["didAccept"]:
            # An accepted merge must never decrease the objective.
            assert newEv >= oldEv
            if mPairIDs is not None:
                # Component IDs change after a merge; keep pending pairs valid.
                mPairIDs = _reindexCandidatePairsAfterAcceptedMerge(
                    mPairIDs, kA, kB)
        trialID += 1
        MTracker.recordResult(**MoveInfo)
    return hmodel, SS, newEv, MTracker