示例#1
0
def preselect_all_merge_candidates(
    curModel, SS, randstate=np.random, preselectroutine="random", mergePerLap=10, compIDs=None, **kwargs
):
    """Create and return a list of (kA, kB) component-ID pairs to try merging.

    Args
    --------
    curModel : bnpy HModel. Only curModel.allocModel.K is read directly;
        the model is otherwise passed on to the pair selector.
    SS : bnpy SuffStatBag. If None (first lap), 'random' selection is used
        regardless of preselectroutine.
    randstate : numpy random number generator, passed to the selector.
    preselectroutine : name of procedure to select candidate pairs, one of
        {'random', 'marglik', 'allpairsfromlist',
         'allpairsfromlistbipartite', 'bestnmatchfromlist',
         'freshbestmatch', 'marglikfromlistbipartite'}.
        Any other name yields an empty list.
    mergePerLap : int max number of candidates to identify
        (may be fewer if K is small).
    compIDs : sequence of int component IDs that the list-based routines
        focus on. Defaults to empty. Never modified.
    **kwargs : ignored.

    Returns
    --------
    mPairList : list of (kA, kB) tuples of component IDs,
        with len(mPairList) <= mergePerLap.
    """
    import itertools

    nMergeTrials = mergePerLap
    # Max number of partners gathered per listed comp by the *fromlist routines
    nPartnersPerComp = 3
    K = curModel.allocModel.K
    if SS is None:  # Handle first lap: no suff stats yet to score pairs with
        preselectroutine = "random"
    # None default avoids a shared mutable default argument
    compIDs = sorted(compIDs) if compIDs is not None else list()
    # Comps NOT in compIDs, ascending so that truncation below is deterministic
    partnerIDs = sorted(set(range(K)).difference(compIDs))

    def _interleave_by_rank(groups):
        """Round-robin flatten: every group's 1st pair, then every 2nd, ...

        Puts each listed comp's top pair ahead of any comp's runner-up,
        even when the groups have unequal lengths.
        """
        merged = list()
        maxLen = max([len(g) for g in groups]) if groups else 0
        for rank in range(maxLen):
            for g in groups:
                if rank < len(g):
                    merged.append(g[rank])
        return merged

    def _append_selected_pairs(MTracker, MSelector, pairs, nDone, mergename):
        """Append selector-chosen pairs until nMergeTrials or none remain."""
        while MTracker.hasAvailablePairs() and nDone < nMergeTrials:
            kA, kB = MSelector.select_merge_components(curModel, SS, MTracker, mergename=mergename, randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            nDone += 1
        return pairs

    pairs = list()
    if preselectroutine == "allpairsfromlist":
        # Every pair within compIDs, in lexicographic order
        pairs = list(itertools.combinations(compIDs, 2))[:nMergeTrials]
    elif preselectroutine == "allpairsfromlistbipartite":
        # Every (listed comp, unlisted comp) pair, smaller ID first
        pairs = [(min(kA, kB), max(kA, kB)) for kA in compIDs for kB in partnerIDs]
        pairs = pairs[:nMergeTrials]
    elif preselectroutine == "bestnmatchfromlist":
        # Find up to nPartnersPerComp best marglik pairs for each listed comp
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        pairGroups = list()
        trial = 0  # number of pairs recorded in MTracker so far
        for kFocus in compIDs:
            if not MTracker.hasAvailablePairs() or trial >= nMergeTrials:
                break
            group = list()
            while MTracker.hasAvailablePartnersForComp(kFocus) and len(group) < nPartnersPerComp:
                kA, kB = MSelector.select_merge_components(
                    curModel, SS, MTracker, mergename="marglik", kA=kFocus, randstate=randstate
                )
                MTracker.recordResult(kA=kA, kB=kB)
                group.append((kA, kB))
                trial += 1
            pairGroups.append(group)
        # Reorder so we're likely to try every listed comp at least once
        pairs = _interleave_by_rank(pairGroups)[:nMergeTrials]
        # Spend any remaining budget on the best marglik pairs from anywhere
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, trial, "marglik")
    elif preselectroutine == "freshbestmatch":
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        # First, the best marglik partner for each listed comp, in order
        while MTracker.hasAvailablePairs() and trial < min(len(compIDs), nMergeTrials):
            kA, kB = MSelector.select_merge_components(
                curModel, SS, MTracker, mergename="marglik", kA=compIDs[trial], randstate=randstate
            )
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1
        # Then spend any remaining budget on the best pairs from anywhere
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, trial, "marglik")
    elif preselectroutine == "random":
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, 0, "random")
    elif preselectroutine == "marglik":
        # Score every pair kA < kB (upper triangle of M)
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        for kA, kB in itertools.combinations(range(K), 2):
            M[kA, kB] = MSelector._calcMScoreForCandidatePair(curModel, SS, kA, kB)
        # find the n largest non-zero entries
        flatM = M.flatten()
        bestIDs = np.argsort(flatM)[::-1]
        bestIDs = bestIDs[flatM[bestIDs] != 0]
        bestrs, bestcs = np.unravel_index(bestIDs, M.shape)
        assert np.all(bestrs < bestcs)
        pairs = list(zip(bestrs[:nMergeTrials].tolist(), bestcs[:nMergeTrials].tolist()))
    elif preselectroutine == "marglikfromlistbipartite":
        # Best few partners for each listed comp, drawn only from comps
        # outside the list
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        pairGroups = list()
        for kA in compIDs:
            for kB in partnerIDs:
                M[kA, kB] = MSelector._calcMScoreForCandidatePair(curModel, SS, kA, kB)
            # column indices of the largest non-zero scores in row kA, best first
            bestIDs = np.argsort(M[kA, :])[::-1]
            bestIDs = bestIDs[M[kA, bestIDs] != 0][:nPartnersPerComp]
            pairGroups.append([(min(kA, int(kB)), max(kA, int(kB))) for kB in bestIDs])
        # Reorder so we're likely to try every listed comp at least once
        pairs = _interleave_by_rank(pairGroups)[:nMergeTrials]
    assert len(pairs) <= nMergeTrials
    return pairs
示例#2
0
def preselect_all_merge_candidates(curModel,
                                   SS,
                                   randstate=np.random,
                                   preselectroutine='random',
                                   mergePerLap=10,
                                   compIDs=None,
                                   **kwargs):
    ''' Create and return a list of (kA, kB) component-ID pairs to try merging.

      Args
      --------
      curModel : bnpy HModel. Only curModel.allocModel.K is read directly;
          the model is otherwise passed on to the pair selector.
      SS : bnpy SuffStatBag. If None (first lap), 'random' selection is
          used regardless of preselectroutine.
      randstate : numpy random number generator, passed to the selector.
      preselectroutine : name of procedure to select candidate pairs, one of
          {'random', 'marglik', 'allpairsfromlist',
           'allpairsfromlistbipartite', 'bestnmatchfromlist',
           'freshbestmatch', 'marglikfromlistbipartite'}.
          Any other name yields an empty list.
      mergePerLap : int max number of candidates to identify
          (may be fewer if K is small).
      compIDs : sequence of int component IDs that the list-based routines
          focus on. Defaults to empty. Never modified.
      **kwargs : ignored.

      Returns
      --------
      mPairList : list of (kA, kB) tuples of component IDs,
          with len(mPairList) <= mergePerLap.
    '''
    import itertools

    nMergeTrials = mergePerLap
    # Max number of partners gathered per listed comp by the *fromlist routines
    nPartnersPerComp = 3
    K = curModel.allocModel.K
    if SS is None:  # Handle first lap: no suff stats yet to score pairs with
        preselectroutine = 'random'
    # None default avoids a shared mutable default argument
    compIDs = sorted(compIDs) if compIDs is not None else list()
    # Comps NOT in compIDs, ascending so that truncation below is deterministic
    partnerIDs = sorted(set(range(K)).difference(compIDs))

    def _interleave_by_rank(groups):
        ''' Round-robin flatten: every group's 1st pair, then every 2nd, ...

        Puts each listed comp's top pair ahead of any comp's runner-up,
        even when the groups have unequal lengths.
        '''
        merged = list()
        maxLen = max([len(g) for g in groups]) if groups else 0
        for rank in range(maxLen):
            for g in groups:
                if rank < len(g):
                    merged.append(g[rank])
        return merged

    def _append_selected_pairs(MTracker, MSelector, pairs, nDone, mergename):
        ''' Append selector-chosen pairs until nMergeTrials or none remain.
        '''
        while MTracker.hasAvailablePairs() and nDone < nMergeTrials:
            kA, kB = MSelector.select_merge_components(curModel,
                                                       SS,
                                                       MTracker,
                                                       mergename=mergename,
                                                       randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            nDone += 1
        return pairs

    pairs = list()
    if preselectroutine == 'allpairsfromlist':
        # Every pair within compIDs, in lexicographic order
        pairs = list(itertools.combinations(compIDs, 2))[:nMergeTrials]
    elif preselectroutine == 'allpairsfromlistbipartite':
        # Every (listed comp, unlisted comp) pair, smaller ID first
        pairs = [(min(kA, kB), max(kA, kB))
                 for kA in compIDs
                 for kB in partnerIDs]
        pairs = pairs[:nMergeTrials]
    elif preselectroutine == 'bestnmatchfromlist':
        # Find up to nPartnersPerComp best marglik pairs for each listed comp
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        pairGroups = list()
        trial = 0  # number of pairs recorded in MTracker so far
        for kFocus in compIDs:
            if not MTracker.hasAvailablePairs() or trial >= nMergeTrials:
                break
            group = list()
            while (MTracker.hasAvailablePartnersForComp(kFocus)
                   and len(group) < nPartnersPerComp):
                kA, kB = MSelector.select_merge_components(curModel,
                                                           SS,
                                                           MTracker,
                                                           mergename='marglik',
                                                           kA=kFocus,
                                                           randstate=randstate)
                MTracker.recordResult(kA=kA, kB=kB)
                group.append((kA, kB))
                trial += 1
            pairGroups.append(group)
        # Reorder so we're likely to try every listed comp at least once
        pairs = _interleave_by_rank(pairGroups)[:nMergeTrials]
        # Spend any remaining budget on the best marglik pairs from anywhere
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, trial,
                                       'marglik')
    elif preselectroutine == 'freshbestmatch':
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        trial = 0
        # First, the best marglik partner for each listed comp, in order
        while MTracker.hasAvailablePairs() and trial < min(
                len(compIDs), nMergeTrials):
            kA, kB = MSelector.select_merge_components(curModel,
                                                       SS,
                                                       MTracker,
                                                       mergename='marglik',
                                                       kA=compIDs[trial],
                                                       randstate=randstate)
            MTracker.recordResult(kA=kA, kB=kB)
            pairs.append((kA, kB))
            trial += 1
        # Then spend any remaining budget on the best pairs from anywhere
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, trial,
                                       'marglik')
    elif preselectroutine == 'random':
        MTracker = MergeTracker(K)
        MSelector = MergePairSelector()
        pairs = _append_selected_pairs(MTracker, MSelector, pairs, 0, 'random')
    elif preselectroutine == 'marglik':
        # Score every pair kA < kB (upper triangle of M)
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        for kA, kB in itertools.combinations(range(K), 2):
            M[kA, kB] = MSelector._calcMScoreForCandidatePair(
                curModel, SS, kA, kB)
        # find the n largest non-zero entries
        flatM = M.flatten()
        bestIDs = np.argsort(flatM)[::-1]
        bestIDs = bestIDs[flatM[bestIDs] != 0]
        bestrs, bestcs = np.unravel_index(bestIDs, M.shape)
        assert np.all(bestrs < bestcs)
        pairs = list(zip(bestrs[:nMergeTrials].tolist(),
                         bestcs[:nMergeTrials].tolist()))
    elif preselectroutine == 'marglikfromlistbipartite':
        # Best few partners for each listed comp, drawn only from comps
        # outside the list
        MSelector = MergePairSelector()
        M = np.zeros((K, K))
        pairGroups = list()
        for kA in compIDs:
            for kB in partnerIDs:
                M[kA, kB] = MSelector._calcMScoreForCandidatePair(
                    curModel, SS, kA, kB)
            # column indices of the largest non-zero scores in row kA,
            # best first
            bestIDs = np.argsort(M[kA, :])[::-1]
            bestIDs = bestIDs[M[kA, bestIDs] != 0][:nPartnersPerComp]
            pairGroups.append([(min(kA, int(kB)), max(kA, int(kB)))
                               for kB in bestIDs])
        # Reorder so we're likely to try every listed comp at least once
        pairs = _interleave_by_rank(pairGroups)[:nMergeTrials]
    assert len(pairs) <= nMergeTrials
    return pairs
示例#3
0
def run_many_merge_moves(
    hmodel, Data, SS, evBound=None, nMergeTrials=1, compList=list(), randstate=np.random, mPairIDs=None, **mergeKwArgs
):
    """ Run (potentially many) merge move on hmodel

      Args
      -------
      hmodel
      Data
      SS
      nMergeTrials : number of merges to try
      compList : list of components to include in attempted merges
      randstate : numpy random number generator

      Returns
      -------
      hmodel
      SS
      evBound
      MTracker
  """
    nMergeTrials = np.maximum(nMergeTrials, len(compList))

    MTracker = MergeTracker(SS.K)
    MSelector = MergePairSelector()

    # Exclude all pairs for which we did not compute the combined entropy Hz
    #  Hz is always stored in KxK matrix. Pairs that were skipped have zeros.
    aList = list()
    bList = list()
    if SS.hasMergeTerm("ElogqZ"):
        Hz = SS.getMergeTerm("ElogqZ")
        for kA in xrange(SS.K):
            for kB in xrange(kA + 1, SS.K):
                if Hz[kA, kB] == 0:
                    aList.append(kA)
                    bList.append(kB)
    if len(aList) > 0:
        MTracker.addPairsToExclude(aList, bList)

    if evBound is None:
        newEv = hmodel.calc_evidence(SS=SS)
    else:
        newEv = evBound

    trialID = 0
    shift = np.zeros(SS.K, dtype=np.int32)
    while trialID < nMergeTrials and MTracker.hasAvailablePairs():
        oldEv = newEv

        if mPairIDs is not None:
            if len(mPairIDs) == 0:
                break
            kA, kB = mPairIDs.pop(0)
            try:
                MTracker.verifyPair(kA, kB)
            except AssertionError:
                print "  AssertionError skipped with mPairIDs!", kA, kB
                continue
        elif len(compList) > 0:
            kA = compList.pop()
            if kA not in MTracker.getAvailableComps():
                continue
            kB = None
        else:
            kA = None
            kB = None

        hmodel, SS, newEv, MoveInfo = run_merge_move(
            hmodel,
            Data,
            SS,
            oldEv,
            kA=kA,
            kB=kB,
            randstate=randstate,
            MSelector=MSelector,
            MTracker=MTracker,
            **mergeKwArgs
        )
        if MoveInfo["didAccept"]:
            assert newEv >= oldEv
            if mPairIDs is not None:
                mPairIDs = _reindexCandidatePairsAfterAcceptedMerge(mPairIDs, kA, kB)
        trialID += 1
        MTracker.recordResult(**MoveInfo)

    return hmodel, SS, newEv, MTracker
示例#4
0
def run_many_merge_moves(hmodel,
                         Data,
                         SS,
                         evBound=None,
                         nMergeTrials=1,
                         compList=list(),
                         randstate=np.random,
                         mPairIDs=None,
                         **mergeKwArgs):
    ''' Run (potentially many) merge move on hmodel

      Args
      -------
      hmodel
      Data
      SS
      nMergeTrials : number of merges to try
      compList : list of components to include in attempted merges
      randstate : numpy random number generator

      Returns
      -------
      hmodel
      SS
      evBound
      MTracker
  '''
    nMergeTrials = np.maximum(nMergeTrials, len(compList))

    MTracker = MergeTracker(SS.K)
    MSelector = MergePairSelector()

    # Exclude all pairs for which we did not compute the combined entropy Hz
    #  Hz is always stored in KxK matrix. Pairs that were skipped have zeros.
    aList = list()
    bList = list()
    if SS.hasMergeTerm('ElogqZ'):
        Hz = SS.getMergeTerm('ElogqZ')
        for kA in xrange(SS.K):
            for kB in xrange(kA + 1, SS.K):
                if Hz[kA, kB] == 0:
                    aList.append(kA)
                    bList.append(kB)
    if len(aList) > 0:
        MTracker.addPairsToExclude(aList, bList)

    if evBound is None:
        newEv = hmodel.calc_evidence(SS=SS)
    else:
        newEv = evBound

    trialID = 0
    shift = np.zeros(SS.K, dtype=np.int32)
    while trialID < nMergeTrials and MTracker.hasAvailablePairs():
        oldEv = newEv

        if mPairIDs is not None:
            if len(mPairIDs) == 0:
                break
            kA, kB = mPairIDs.pop(0)
            try:
                MTracker.verifyPair(kA, kB)
            except AssertionError:
                print '  AssertionError skipped with mPairIDs!', kA, kB
                continue
        elif len(compList) > 0:
            kA = compList.pop()
            if kA not in MTracker.getAvailableComps():
                continue
            kB = None
        else:
            kA = None
            kB = None

        hmodel, SS, newEv, MoveInfo = run_merge_move(hmodel,
                                                     Data,
                                                     SS,
                                                     oldEv,
                                                     kA=kA,
                                                     kB=kB,
                                                     randstate=randstate,
                                                     MSelector=MSelector,
                                                     MTracker=MTracker,
                                                     **mergeKwArgs)
        if MoveInfo['didAccept']:
            assert newEv >= oldEv
            if mPairIDs is not None:
                mPairIDs = _reindexCandidatePairsAfterAcceptedMerge(
                    mPairIDs, kA, kB)
        trialID += 1
        MTracker.recordResult(**MoveInfo)

    return hmodel, SS, newEv, MTracker