Example #1
def logLPConvergenceDiagnostics(refineInfo, rstep=0, b_nRefineSteps=0):
    if 'xLPslice' not in refineInfo:
        return
    xLPslice = refineInfo['xLPslice']
    if '_maxDiff' not in xLPslice:
        return

    msg = " LP info "
    #msg = "step %d/%d " % (rstep + 1, b_nRefineSteps)
    target_docs = np.flatnonzero(xLPslice['_maxDiff'] >= 0)
    if target_docs.size == 0:
        BLogger.pprint(msg + "No docs with active local step.")
        return

    msg += "nCAIters "
    for p in [0, 10, 50, 90, 100]:
        if p > 0:
            msg += "|"
        ip = np.percentile(xLPslice['_nIters'][target_docs], p)
        msg += " %3d%% %7d" % (p, ip)
    msg += "\n         Ndiff    "
    for p in [0, 10, 50, 90, 100]:
        if p > 0:
            msg += "|"
        md = np.percentile(xLPslice['_maxDiff'][target_docs], p)
        msg += " %3d%% %7.3f" % (p, md)
    BLogger.pprint(msg)
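
A minimal usage sketch of the refineInfo structure this logger expects. The '_maxDiff' and '_nIters' arrays are per-document diagnostics from the restricted local step; the dummy values below are assumptions for illustration only.

import numpy as np

refineInfo = {
    'xLPslice': {
        # One entry per document; negative _maxDiff marks docs with
        # no active local step.
        '_maxDiff': np.asarray([0.02, -1.0, 0.3, 0.001, 0.15]),
        '_nIters': np.asarray([12, 0, 87, 3, 41]),
    }
}
xLPslice = refineInfo['xLPslice']
target_docs = np.flatnonzero(xLPslice['_maxDiff'] >= 0)
for p in [0, 10, 50, 90, 100]:
    print('%3d%% nCAIters %7d  Ndiff %7.3f' % (
        p,
        np.percentile(xLPslice['_nIters'][target_docs], p),
        np.percentile(xLPslice['_maxDiff'][target_docs], p)))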
Example #2
def makeSummaryForBirthProposal_HTMLWrapper(Dslice, curModel, curLPslice,
                                            **kwargs):
    ''' Thin wrapper around makeSummaryForBirthProposal that produces HTML.

    Will produce HTML output regardless of whether makeSummaryForBirthProposal
    succeeds or fails somewhere in the construction process.

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
    '''
    targetUID = kwargs['targetUID']
    BLogger.startUIDSpecificLog(targetUID)

    # Make an output directory for HTML
    if kwargs['b_debugWriteHTML']:
        kwargs['b_debugOutputDir'] = createBirthProposalHTMLOutputDir(**kwargs)
    else:
        if 'b_debugOutputDir' in kwargs:
            if kwargs['b_debugOutputDir'].lower() == 'none':
                kwargs['b_debugOutputDir'] = None

    doExtendExistingProposal = False
    if 'curSSwhole' in kwargs:
        curSSwhole = kwargs['curSSwhole']
        if hasattr(curSSwhole, 'propXSS'):
            if targetUID in curSSwhole.propXSS:
                doExtendExistingProposal = True

    if doExtendExistingProposal:
        xSSslice, DebugInfo = makeSummaryForExistingBirthProposal(
            Dslice, curModel, curLPslice, **kwargs)
    else:
        xSSslice, DebugInfo = makeSummaryForBirthProposal(
            Dslice, curModel, curLPslice, **kwargs)

    # Write output to HTML
    if 'b_debugOutputDir' in kwargs and kwargs['b_debugOutputDir']:
        htmlstr = makeSingleProposalHTMLStr(DebugInfo, **kwargs)
        htmlfilepath = os.path.join(kwargs['b_debugOutputDir'], 'index.html')
        with open(htmlfilepath, 'w') as f:
            f.write(htmlstr)
    BLogger.stopUIDSpecificLog(targetUID)
    return xSSslice, DebugInfo
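
The dispatch above hinges on whether curSSwhole already carries a stored proposal for the target UID. A toy sketch of that pattern; SuffStatBagStub and choose_proposal_path are hypothetical stand-ins, not bnpy API.

class SuffStatBagStub(object):
    # Stand-in for bnpy's SuffStatBag; propXSS maps targetUID -> stats.
    def __init__(self, propXSS=None):
        if propXSS is not None:
            self.propXSS = propXSS

def choose_proposal_path(targetUID, curSSwhole=None):
    doExtend = (curSSwhole is not None
                and hasattr(curSSwhole, 'propXSS')
                and targetUID in curSSwhole.propXSS)
    return 'extend-existing' if doExtend else 'create-new'

print(choose_proposal_path(7, SuffStatBagStub({7: 'stats'})))  # extend-existing
print(choose_proposal_path(7, SuffStatBagStub()))              # create-new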
Example #3
    print('korig %d' % (korig))
    # Determine what is hiding inside it that shouldn't be
    mask = AZ == ktarget
    nTarget = np.sum(mask)
    print('%d total atoms assigned to ktarget...' % (nTarget))
    trueLabels = np.asarray(np.unique(Zref[mask]), np.int32)
    for ll in trueLabels:
        nTrue = np.sum(Zref[mask] == ll)
        print('%d/%d should have true label %d: %s' %
              (nTrue, nTarget, ll, chr(65 + ll)))
    return korig


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('taskoutpath', type=str)
    parser.add_argument('--lap', type=float, default=None)
    parser.add_argument('--lapFrac', type=float, default=None)
    parser.add_argument('--outputdir', type=str, default='/tmp/')
    parser.add_argument('--targetUID', type=int, default=0)
    parser.add_argument('--batchID', type=int, default=None)
    for key, val in list(DefaultBirthArgs.items()):
        parser.add_argument('--' + key, type=type(val), default=None)
    args = parser.parse_args()

    BLogger.configure(args.outputdir,
                      doSaveToDisk=0,
                      doWriteStdOut=1,
                      stdoutLevel=0)
    tryBirthForTask(**args.__dict__)
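
The loop over DefaultBirthArgs above auto-generates one optional command-line flag per birth hyperparameter. A self-contained sketch with a made-up defaults dict; DefaultBirthArgsToy only stands in for bnpy's real DefaultBirthArgs.

import argparse

DefaultBirthArgsToy = dict(
    b_Kfresh=10, b_nRefineSteps=3, b_minNumAtomsForNewComp=1.0)

parser = argparse.ArgumentParser()
for key, val in list(DefaultBirthArgsToy.items()):
    # type(val) makes argparse coerce '--b_Kfresh 5' to int, 1.0 to float, etc.
    parser.add_argument('--' + key, type=type(val), default=None)

args = parser.parse_args(['--b_Kfresh', '5'])
print(args.b_Kfresh, args.b_nRefineSteps)  # 5 None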
Example #4
def _run_task_internal(jobname, taskid, nTask, ReqArgs, KwArgs, UnkArgs,
                       dataName, allocModelName, obsModelName, algName,
                       doSaveToDisk, doWriteStdOut):
    """ Internal method (should never be called by end-user!)
        Executes learning for a particular job and particular taskid.

        Returns
        -------
        hmodel : bnpy HModel, fit to the data
        RunInfo : dict of information about the run, with fields
        - 'loss' : final loss value for algorithm
        - 'loss_history' : vector of loss values over time
    """
    # Make shallow copies of input dicts, so that any modifications here
    # do not propagate back to the caller.
    ReqArgs = dict(**ReqArgs)
    KwArgs = dict(**KwArgs)
    UnkArgs = dict(**UnkArgs)

    algseed = createUniqueRandomSeed(jobname, taskID=taskid)
    dataorderseed = createUniqueRandomSeed('', taskID=taskid)
    KwArgs[algName]['algseed'] = algseed
    KwArgs[algName]['dataorderseed'] = dataorderseed

    if algName in OnlineDataAlgSet:
        KwArgs[algName]['nLap'] = KwArgs['OnlineDataPrefs']['nLap']

    if isinstance(dataName, str):
        if os.path.exists(dataName):
            # dataName is a path to many data files on disk
            Data, InitData = loadDataIteratorFromDisk(dataName, ReqArgs,
                                                      KwArgs, dataorderseed)
            DataArgs = UnkArgs
            # Set the short name for this dataset,
            # so that the filepath for results is informative.
            if not hasattr(Data, 'name'):
                try:
                    Data.name = KwArgs['OnlineDataPrefs']['datasetName']
                except KeyError:
                    Data.name = 'UnknownDatasetName'
        else:
            DataArgs = getKwArgsForLoadData(ReqArgs, UnkArgs, KwArgs)
            Data, InitData = loadData(ReqArgs, KwArgs, DataArgs, dataorderseed)
    else:
        Data = dataName
        InitData = dataName
        DataArgs = dict()
        assert isinstance(Data, bnpy.data.DataObj)
        if algName in OnlineDataAlgSet:
            OnlineDataArgs = KwArgs['OnlineDataPrefs']
            OnlineDataArgs['dataorderseed'] = dataorderseed

            DataArgs = getKwArgsForLoadData(Data, UnkArgs)
            OnlineDataArgs.update(DataArgs)  # add custom args
            Data = Data.to_iterator(**OnlineDataArgs)
    if hasattr(Data, 'name'):
        ReqArgs['dataName'] = Data.name
    if doSaveToDisk:
        task_output_path = make_task_output_path(ReqArgs,
                                                 KwArgs,
                                                 taskID=taskid)
        createEmptyOutputPathOnDisk(task_output_path)
        writeArgsToFile(ReqArgs, KwArgs, task_output_path, UnkArgs)
    else:
        task_output_path = None
    KwArgs['OutputPrefs']['task_output_path'] = task_output_path
    jobID = configLoggingToConsoleAndFile(task_output_path, taskid,
                                          doSaveToDisk, doWriteStdOut)

    # Write descriptions to the log
    if taskid == 1 or jobID > 0:
        # Warn user about any unknown keyword arguments
        showWarningForUnknownArgs(UnkArgs, DataArgs)

        Log.info('Dataset Summary:')
        Log.info(Data.get_text_summary())
        Log.info(Data.get_stats_summary())

    # Create and initialize model parameters
    hmodel = make_initialized_model(
        InitData,
        seed=algseed,
        taskid=taskid,
        allocModelName=ReqArgs['allocModelName'],
        obsModelName=ReqArgs['obsModelName'],
        algName=ReqArgs['algName'],
        KwArgs=KwArgs,
        verbose=(taskid == 1 or jobID > 0),
    )

    # Create learning algorithm
    learnAlg = createLearnAlg(Data,
                              hmodel,
                              ReqArgs,
                              KwArgs,
                              algseed=algseed,
                              task_output_path=task_output_path)
    if learnAlg.hasMove('birth'):
        import bnpy.birthmove.BLogger as BirthLogger
        BirthLogger.configure(task_output_path, doSaveToDisk, doWriteStdOut)
    if learnAlg.hasMove('delete'):
        import bnpy.deletemove.DLogger as DeleteLogger
        DeleteLogger.configure(task_output_path, doSaveToDisk, doWriteStdOut)
    if learnAlg.hasMove('merge'):
        import bnpy.mergemove.MLogger as MergeLogger
        MergeLogger.configure(task_output_path, doSaveToDisk, doWriteStdOut)
    if learnAlg.hasMove('shuffle'):
        import bnpy.mergemove.SLogger as SLogger
        SLogger.configure(task_output_path, doSaveToDisk, doWriteStdOut)
    if str(type(hmodel.allocModel)).count('TopicModel'):
        import bnpy.allocmodel.topics.LocalStepLogger as LocalStepLogger
        LocalStepLogger.configure(task_output_path, doSaveToDisk,
                                  doWriteStdOut)

    # Set up logging for how long each step of the alg takes.
    import bnpy.learnalg.ElapsedTimeLogger as ElapsedTimeLogger
    ElapsedTimeLogger.configure(task_output_path, KwArgs['MoveNames'],
                                doSaveToDisk, doWriteStdOut)

    Log.info(
        'Learn Alg: %s | task %2d/%d | alg. seed: %d | data order seed: %d' %
        (algName, taskid, nTask, algseed, dataorderseed))
    Log.info('task_output_path: %s' % (task_output_path))

    # Fit the model to the data!
    RunInfo = learnAlg.fit(hmodel, Data)
    RunInfo['UnkArgs'] = UnkArgs
    RunInfo['KwArgs'] = KwArgs
    RunInfo['ReqArgs'] = ReqArgs
    return hmodel, RunInfo
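
Two distinct seeds are drawn above: algseed drives the learning algorithm's own randomness, while dataorderseed controls only the batch visit order, so reruns can vary one independently of the other. A hypothetical sketch of deriving a stable per-task seed; createUniqueRandomSeed's actual implementation may differ.

import hashlib

def make_unique_seed(jobname, taskID=1):
    # Hash jobname+taskID into a stable 32-bit seed (illustrative only).
    key = ('%s-%d' % (jobname, taskID)).encode('utf-8')
    return int(hashlib.md5(key).hexdigest(), 16) % (2 ** 32)

algseed = make_unique_seed('myjob', taskID=3)   # algorithm randomness
dataorderseed = make_unique_seed('', taskID=3)  # batch ordering only
print(algseed, dataorderseed)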
Example #5
def selectCompsForBirthAtCurrentBatch(hmodel=None,
                                      SS=None,
                                      SSbatch=None,
                                      MoveRecordsByUID=dict(),
                                      MovePlans=dict(),
                                      lapFrac=0,
                                      batchID=0,
                                      batchPos=0,
                                      nBatch=1,
                                      isFirstBatch=False,
                                      doPrintLotsOfDetails=True,
                                      **BArgs):
    ''' Select specific comps to target with birth move at current batch.

    Returns
    -------
    MovePlans : dict with updated fields
    * b_targetUIDs : list of ints,
        Each uid in b_targetUIDs will be tried immediately, at current batch.

    MoveRecordsByUID : dict with updated fields
    * [uid]['byBatch'][batchID] : dict with fields
        proposalBatchSize
        proposalTotalSize
    '''
    # Extract num clusters in current model
    K = SS.K
    if K > 25:
        doPrintLotsOfDetails = False
    statusStr = ' lap %7.3f lapCeil %5d batchPos %3d/%d batchID %3d ' % (
        lapFrac, np.ceil(lapFrac), batchPos, nBatch, batchID)
    BLogger.pprint('PLAN at ' + statusStr)

    if BArgs['Kmax'] - SS.K <= 0:
        msg = "Cannot plan any more births." + \
            " Reached upper limit of %d existing comps (--Kmax)." % (
                BArgs['Kmax'])
        BLogger.pprint(msg)
        if 'b_targetUIDs' in MovePlans:
            del MovePlans['b_targetUIDs']
        MovePlans['b_statusMsg'] = msg
        BLogger.pprint('')
        return MovePlans

    if isFirstBatch:
        assert 'b_targetUIDs' not in MovePlans

    if isFirstBatch or 'b_firstbatchUIDs' not in MovePlans:
        MovePlans['b_firstbatchUIDs'] = SSbatch.uids.copy()
        MovePlans['b_CountVec_SeenThisLap'] = np.zeros(K)
    for k, uid in enumerate(MovePlans['b_firstbatchUIDs']):
        MovePlans['b_CountVec_SeenThisLap'][k] += SSbatch.getCountForUID(uid)

    # Short-circuit. Keep retained clusters.
    if lapFrac > 1.0 and BArgs['b_retainAcrossBatchesAfterFirstLap']:
        if not isFirstBatch:
            if 'b_targetUIDs' in MovePlans:
                msg = "%d UIDs retained from proposals earlier this lap." + \
                    " No new proposals at this batch."
                msg = msg % (len(MovePlans['b_targetUIDs']))
                BLogger.pprint(msg)
                if len(MovePlans['b_targetUIDs']) > 0:
                    BLogger.pprint(vec2str(MovePlans['b_targetUIDs']))
            else:
                BLogger.pprint(
                    'No UIDs targeted earlier in lap.' + \
                    ' No new proposals at this batch.')
            return MovePlans

    # Compute sizes for each cluster
    CountVec_b = np.maximum(SSbatch.getCountVec(), 1e-100)
    CountVec_all = np.maximum(SS.getCountVec(), 1e-100)
    atomstr = 'atoms'
    labelstr = 'count_b'

    uidsBusyWithOtherMoves = list()
    uidsTooSmall = list()
    uidsWithFailRecord = list()
    eligible_mask = np.zeros(K, dtype=bool)
    for ii, uid in enumerate(SS.uids):
        if uid not in MoveRecordsByUID:
            MoveRecordsByUID[uid] = defaultdict(int)
        if not isinstance(MoveRecordsByUID[uid]['byBatch'], dict):
            MoveRecordsByUID[uid]['byBatch'] = \
                defaultdict(lambda: defaultdict(int))
        uidRec = MoveRecordsByUID[uid]
        uidRec_b = MoveRecordsByUID[uid]['byBatch'][batchID]

        uidstatusStr = "STATUS uid %5d %s N_b %9.3f N_ttl %9.3f" % (
            uid, statusStr, SSbatch.getCountForUID(uid),
            SS.getCountForUID(uid))
        # Continue to track UIDs that are pre-existing targets
        if 'b_targetUIDs' in MovePlans:
            if uid in MovePlans['b_targetUIDs']:
                BLogger.startUIDSpecificLog(uid)
                BLogger.pprint(uidstatusStr + " CHOSENAGAIN")
                BLogger.stopUIDSpecificLog(uid)
                continue
        # TODO REMOVE DEAD CODE
        if MoveRecordsByUID[uid]['b_tryAgainFutureLap'] > 0:
            msg = "Try targeting uid %d again." % (uid)
            BLogger.pprint(msg)
            del MoveRecordsByUID[uid]['b_tryAgainFutureLap']
            eligible_mask[ii] = 1
            continue

        # Discard uids which are active in another proposal.
        if 'd_targetUIDs' in MovePlans:
            if uid in MovePlans['d_targetUIDs']:
                uidsBusyWithOtherMoves.append(uid)
                BLogger.startUIDSpecificLog(uid)
                BLogger.pprint(uidstatusStr + " BUSY DELETE PROPOSAL")
                BLogger.stopUIDSpecificLog(uid)
                continue
        if 'd_absorbingUIDSet' in MovePlans:
            if uid in MovePlans['d_absorbingUIDSet']:
                uidsBusyWithOtherMoves.append(uid)
                BLogger.startUIDSpecificLog(uid)
                BLogger.pprint(uidstatusStr + " BUSY DELETE PROPOSAL")
                BLogger.stopUIDSpecificLog(uid)
                continue

        if 'm_targetUIDSet' in MovePlans:
            if uid in MovePlans['m_targetUIDSet']:
                uidsBusyWithOtherMoves.append(uid)
                BLogger.startUIDSpecificLog(uid)
                BLogger.pprint(uidstatusStr + " BUSY MERGE PROPOSAL")
                BLogger.stopUIDSpecificLog(uid)
                continue

        # Filter out uids without large presence in current batch
        bigEnough = CountVec_b[ii] >= BArgs['b_minNumAtomsForTargetComp']
        if not bigEnough:
            uidsTooSmall.append((uid, CountVec_b[ii]))
            BLogger.startUIDSpecificLog(uid)
            BLogger.pprint(
                uidstatusStr + " TOO SMALL %.2f < %.2f" %
                (CountVec_b[ii], BArgs['b_minNumAtomsForTargetComp']))
            BLogger.stopUIDSpecificLog(uid)
            continue

        eligibleSuffix = ''
        # Filter out uids we've failed on this particular batch before
        if uidRec_b['nFail'] > 0:
            prevBatchSize = uidRec_b['proposalBatchSize']
            prevTotalSize = uidRec_b['proposalTotalSize']

            curBatchSize = SSbatch.getCountForUID(uid)
            sizePercDiff = np.abs(curBatchSize -
                                  prevBatchSize) / (curBatchSize + 1e-100)
            sizeChangedEnoughToReactivate = sizePercDiff > \
                BArgs['b_minPercChangeInNumAtomsToReactivate']

            curTotalSize = SS.getCountForUID(uid)
            totalPercDiff = np.abs(curTotalSize -
                                   prevTotalSize) / (curTotalSize + 1e-100)
            totalsizeChangedEnoughToReactivate = totalPercDiff > \
                BArgs['b_minPercChangeInNumAtomsToReactivate']

            if sizeChangedEnoughToReactivate:
                eligibleSuffix = \
                    "REACTIVATE BY BATCH SIZE." + \
                    "\n Batch size percDiff %.2f > %.2f" % (
                        sizePercDiff,
                        BArgs['b_minPercChangeInNumAtomsToReactivate']) \
                    + "\n prevBatchSize %9.2f \n curBatchSize %9.2f" % (
                        prevBatchSize, curBatchSize)
                uidRec_b['nFail'] = 0  # Reactivated
            elif totalsizeChangedEnoughToReactivate:
                eligibleSuffix = \
                    "REACTIVATED BY TOTAL SIZE" + \
                    "\n Total size percDiff %.2f > %.2f" % (
                        totalPercDiff,
                        BArgs['b_minPercChangeInNumAtomsToReactivate']) \
                    + "\n prevTotalSize %9.1f \n curTotalSize %9.1f" % (
                        prevTotalSize, curTotalSize)
                uidRec_b['nFail'] = 0  # Reactivated
            else:
                uidsWithFailRecord.append(uid)
                BLogger.startUIDSpecificLog(uid)
                BLogger.pprint(uidstatusStr + " DISQUALIFIED FOR PAST FAILURE")
                BLogger.stopUIDSpecificLog(uid)
                continue
        # If we've made it here, the uid is eligible.
        eligible_mask[ii] = 1
        BLogger.startUIDSpecificLog(uid)
        BLogger.pprint(uidstatusStr + " ELIGIBLE " + eligibleSuffix)
        BLogger.stopUIDSpecificLog(uid)

    # Notify about uids retained
    if 'b_targetUIDs' not in MovePlans:
        MovePlans['b_targetUIDs'] = list()
    msg = "%d/%d UIDs retained from preexisting proposals." % (len(
        MovePlans['b_targetUIDs']), K)
    BLogger.pprint(msg)

    # Log info about busy disqualifications
    nDQ_toobusy = len(uidsBusyWithOtherMoves)
    nDQ_pastfail = len(uidsWithFailRecord)
    msg = "%d/%d UIDs too busy with other moves (merge/delete)." % (
        nDQ_toobusy, K)
    BLogger.pprint(msg)
    # Log info about toosmall disqualification
    nDQ_toosmall = len(uidsTooSmall)
    msg = "%d/%d UIDs too small (too few %s in current batch)." + \
        " Required size >= %d (--b_minNumAtomsForTargetComp)"
    msg = msg % (nDQ_toosmall, K, atomstr, BArgs['b_minNumAtomsForTargetComp'])
    BLogger.pprint(msg, 'debug')
    if nDQ_toosmall > 0 and doPrintLotsOfDetails:
        lineUID = vec2str([u[0] for u in uidsTooSmall])
        lineSize = vec2str([u[1] for u in uidsTooSmall])
        BLogger.pprint(
            [lineUID, lineSize],
            prefix=['%7s' % 'uids', '%7s' % labelstr],
        )
    # Notify about past failure disqualifications to the log
    BLogger.pprint(
        '%d/%d UIDs disqualified for past failures.' % (nDQ_pastfail, K),
        'debug')
    if nDQ_pastfail > 0 and doPrintLotsOfDetails:
        lineUID = vec2str(uidsWithFailRecord)
        BLogger.pprint(lineUID)
    # Store nDQ counts for reporting.
    MovePlans['b_nDQ_toosmall'] = nDQ_toosmall
    MovePlans['b_nDQ_toobusy'] = nDQ_toobusy
    MovePlans['b_nDQ_pastfail'] = nDQ_pastfail
    # Finalize list of eligible UIDs
    eligibleUIDs = SS.uids[eligible_mask]
    BLogger.pprint('%d/%d UIDs eligible for new proposal' %
                   (len(eligibleUIDs), K))
    # EXIT if nothing eligible.
    if len(eligibleUIDs) == 0:
        BLogger.pprint('')
        assert 'b_targetUIDs' in MovePlans
        return MovePlans

    # Record all uids that are eligible!
    # And make vector of how recently they have failed in other attempts
    FailVec = np.inf * np.ones(K)
    for uid in eligibleUIDs:
        MoveRecordsByUID[uid]['b_latestEligibleLap'] = lapFrac
        k = SS.uid2k(uid)
        FailVec[k] = MoveRecordsByUID[uid]['b_nFailRecent']

    if doPrintLotsOfDetails:
        lineUID = vec2str(eligibleUIDs)
        lineSize = vec2str(CountVec_all[eligible_mask])
        lineBatchSize = vec2str(CountVec_b[eligible_mask])
        lineFail = vec2str(FailVec[eligible_mask])
        BLogger.pprint(
            [lineUID, lineSize, lineBatchSize, lineFail],
            prefix=[
                '%7s' % 'uids',
                '%7s' % 'cnt_ttl',
                '%7s' % 'cnt_b',
                '%7s' % 'nFail',
            ],
        )

    # Figure out how many new states we can target this round.
    # Prioritize the top comps as ranked by the desired score
    # until we max out the budget of Kmax total comps.
    maxnewK = BArgs['Kmax'] - SS.K
    totalnewK_perEligibleComp = np.minimum(
        np.ceil(CountVec_b[eligible_mask]),
        np.minimum(BArgs['b_Kfresh'], maxnewK))
    # TODO: Worry about retained ids with maxnewK
    sortorder = argsortBigToSmallByTiers(-1 * FailVec[eligible_mask],
                                         CountVec_b[eligible_mask])
    sortedCumulNewK = np.cumsum(totalnewK_perEligibleComp[sortorder])
    nToKeep = np.searchsorted(sortedCumulNewK, maxnewK + 0.0042)
    if nToKeep == 0:
        nToKeep = 1
    keepEligibleIDs = sortorder[:nToKeep]
    newK = np.minimum(sortedCumulNewK[nToKeep - 1], maxnewK)
    chosenUIDs = [eligibleUIDs[s] for s in keepEligibleIDs]

    if nToKeep < len(eligibleUIDs):
        BLogger.pprint(
            'Selected %d/%d eligible UIDs to do proposals.' % (
                nToKeep, len(eligibleUIDs)) + \
            '\n Could create up to %d new clusters, %d total clusters.' % (
                newK, newK + SS.K) + \
            '\n Total budget allows at most %d clusters (--Kmax).' % (
                BArgs['Kmax']),
            )
    BLogger.pprint('%d/%d UIDs chosen for new proposals (rankby: cnt_b)' %
                   (len(chosenUIDs), len(eligibleUIDs)))
    if doPrintLotsOfDetails:
        lineUID = vec2str(chosenUIDs)
        lineSize = vec2str(CountVec_all[eligible_mask][keepEligibleIDs])
        lineBatchSize = vec2str(CountVec_b[eligible_mask][keepEligibleIDs])
        lineFail = vec2str(FailVec[eligible_mask][keepEligibleIDs])
        BLogger.pprint(
            [lineUID, lineSize, lineBatchSize, lineFail],
            prefix=[
                '%7s' % 'uids',
                '%7s' % 'cnt_ttl',
                '%7s' % 'cnt_b',
                '%7s' % 'fail',
            ],
        )

    for uid in chosenUIDs:
        uidRec = MoveRecordsByUID[uid]
        uidRec['b_proposalBatchID'] = batchID
        uidRec_b = MoveRecordsByUID[uid]['byBatch'][batchID]
        uidRec_b['proposalBatchSize'] = SSbatch.getCountForUID(uid)
        uidRec_b['proposalTotalSize'] = SS.getCountForUID(uid)

    # Aggregate all uids
    MovePlans['b_newlyChosenTargetUIDs'] = chosenUIDs
    MovePlans['b_preExistingTargetUIDs'] = \
        [u for u in MovePlans['b_targetUIDs']]
    MovePlans['b_targetUIDs'].extend(chosenUIDs)

    BLogger.pprint('')
    return MovePlans
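
The selection block above ranks eligible comps in tiers (fewest recent failures first, larger batch count breaking ties) and keeps a prefix of that ranking until the cumulative number of fresh comps would exceed the Kmax budget. A numpy-only sketch of the same logic, with np.lexsort standing in for argsortBigToSmallByTiers:

import numpy as np

Kmax, K, b_Kfresh = 20, 15, 4
maxnewK = Kmax - K                           # room for 5 new comps
count_b = np.asarray([120., 40., 300., 8.])  # batch counts of eligible comps
fail_vec = np.asarray([0, 1, 0, 0])          # recent failure counts

# Primary tier: fewer failures first; tie-break: bigger batch count first.
sortorder = np.lexsort((-count_b, fail_vec))
newK_per_comp = np.minimum(np.ceil(count_b), min(b_Kfresh, maxnewK))
cumNewK = np.cumsum(newK_per_comp[sortorder])
nToKeep = max(1, np.searchsorted(cumNewK, maxnewK + 0.0042))
print(sortorder[:nToKeep])  # comps chosen under the budget, e.g. [2]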
Example #6
def selectShortListForBirthAtLapStart(hmodel,
                                      SS,
                                      MoveRecordsByUID=dict(),
                                      MovePlans=dict(),
                                      lapFrac=0,
                                      b_minNumAtomsForTargetComp=2,
                                      **BArgs):
    ''' Select list of comps to possibly target with birth during next lap.

    Shortlist uids are guaranteed to never be involved in a merge/delete.
    They are kept aside especially for a birth move, at least in this lap.

    Returns
    -------
    MovePlans : dict with updated fields
    * b_shortlistUIDs : list of ints,
        Each uid in b_shortlistUIDs could be a promising birth target.
        None of these should be touched by deletes or merges in this lap.
    '''
    MovePlans['b_shortlistUIDs'] = list()
    MovePlans['b_nDQ_toosmall'] = 0
    MovePlans['b_nDQ_pastfail'] = 0
    MovePlans['b_nDQ_toobusy'] = 0
    MovePlans['b_roomToGrow'] = 0
    MovePlans['b_maxLenShortlist'] = 0
    if not canBirthHappenAtLap(lapFrac, **BArgs):
        BLogger.pprint('')
        return MovePlans

    K = hmodel.obsModel.K
    KroomToGrow = BArgs['Kmax'] - K
    MovePlans['b_roomToGrow'] = KroomToGrow
    # Each birth adds at least 2 comps.
    # If we have 10 slots left, we can do at most 5 births
    maxLenShortlist = KroomToGrow // 2
    MovePlans['b_maxLenShortlist'] = maxLenShortlist

    # EXIT: early, if no room to grow.
    if KroomToGrow <= 1:
        BLogger.pprint(
            "Cannot shortlist any comps for birth." + \
            " Adding 2 more comps to K=%d exceeds limit of %d (--Kmax)." % (
                K, BArgs['Kmax'])
            )
        BLogger.pprint('')
        return MovePlans
    # Log reasons for shortlist length
    if maxLenShortlist < K:
        msg = " Limiting shortlist to %d possible births this lap." % (
            maxLenShortlist)
        msg += " Any more would cause current K=%d to exceed Kmax=%d" % (
            K, BArgs['Kmax'])
        BLogger.pprint(msg)
    # Handle initialization case: SS is None
    # Must just select all possible comps
    if SS is None:
        shortlistUIDs = np.arange(K).tolist()
        shortlistUIDs = shortlistUIDs[:maxLenShortlist]
        MovePlans['b_shortlistUIDs'] = shortlistUIDs
        BLogger.pprint("No SS provided. Shortlist contains %d possible comps" %
                       (len(shortlistUIDs)))
        BLogger.pprint('')
        return MovePlans
    assert SS.K == K

    CountVec = SS.getCountVec()
    eligible_mask = np.zeros(K, dtype=bool)
    nTooSmall = 0
    nPastFail = 0
    for k, uid in enumerate(SS.uids):
        if uid not in MoveRecordsByUID:
            MoveRecordsByUID[uid] = defaultdict(int)
        tooSmall = CountVec[k] <= b_minNumAtomsForTargetComp
        hasFailRecord = MoveRecordsByUID[uid]['b_nFailRecent'] > 0
        if MoveRecordsByUID[uid]['b_tryAgainFutureLap'] > 0:
            eligible_mask[k] = 1
            MovePlans['b_shortlistUIDs'].append(uid)
        elif (not tooSmall) and (not hasFailRecord):
            eligible_mask[k] = 1
            MovePlans['b_shortlistUIDs'].append(uid)
        elif tooSmall:
            nTooSmall += 1
        else:
            assert hasFailRecord
            nPastFail += 1
    assert len(MovePlans['b_shortlistUIDs']) == np.sum(eligible_mask)
    # Rank the shortlist by size
    if maxLenShortlist < len(MovePlans['b_shortlistUIDs']):
        sortIDs = argsort_bigtosmall_stable(CountVec[eligible_mask])
        sortIDs = sortIDs[:maxLenShortlist]
        MovePlans['b_shortlistUIDs'] = [
            MovePlans['b_shortlistUIDs'][s] for s in sortIDs
        ]
        shortlistCountVec = CountVec[eligible_mask][sortIDs]
    else:
        shortlistCountVec = CountVec[eligible_mask]

    MovePlans['b_nDQ_toosmall'] = nTooSmall
    MovePlans['b_nDQ_pastfail'] = nPastFail
    nShortList = len(MovePlans['b_shortlistUIDs'])
    assert nShortList <= maxLenShortlist
    BLogger.pprint("%d/%d uids selected for short list." % (nShortList, K))
    if nShortList > 0:
        lineUID = vec2str(MovePlans['b_shortlistUIDs'])
        lineSize = vec2str(shortlistCountVec)
        BLogger.pprint(
            [lineUID, lineSize],
            prefix=['%7s' % 'uids', '%7s' % 'size'],
        )
    BLogger.pprint('')
    return MovePlans
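
Because each accepted birth adds at least two comps, the shortlist above is capped at half the remaining room under Kmax, and an oversized shortlist keeps only the largest comps. A small sketch; the stable descending argsort stands in for argsort_bigtosmall_stable.

import numpy as np

Kmax, K = 12, 7
KroomToGrow = Kmax - K
maxLenShortlist = KroomToGrow // 2  # integer division: a birth costs >= 2 comps
CountVec = np.asarray([50., 5., 90., 90., 2.])

# Stable big-to-small sort, so equal counts keep their original order.
sortIDs = np.argsort(-CountVec, kind='stable')[:maxLenShortlist]
print(maxLenShortlist, sortIDs)  # 2 [2 3]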
Example #7
def makeSummaryForBirthProposal(Dslice,
                                curModel,
                                curLPslice,
                                curSSwhole=None,
                                b_creationProposalName='bregmankmeans',
                                targetUID=None,
                                ktarget=None,
                                newUIDs=None,
                                LPkwargs=DefaultLPkwargs,
                                lapFrac=0,
                                batchID=0,
                                seed=0,
                                b_nRefineSteps=3,
                                b_debugOutputDir=None,
                                b_minNumAtomsForNewComp=None,
                                b_doInitCompleteLP=1,
                                b_cleanupWithMerge=1,
                                b_method_initCoordAscent='fromprevious',
                                vocabList=None,
                                **kwargs):
    ''' Create summary that reassigns mass from target to Kfresh new comps.

    TODO support other options than bregman???

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
    '''
    # Parse input to decide which cluster to target
    # * targetUID is the unique ID of this cluster
    # * ktarget is its position in the current cluster ordering
    if targetUID is None:
        targetUID = curSSwhole.k2uid(ktarget)
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Creating proposal for targetUID %s at lap %.2f batchID %d' %
        (targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint('HTML output: ' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(curModel,
                        curSSwhole,
                        os.path.join(b_debugOutputDir, 'OrigComps.png'),
                        vocabList=vocabList,
                        compsToHighlight=[ktarget])

    # Determine exactly how many new states we can make...
    xK = len(newUIDs)
    if xK + curSSwhole.K > kwargs['Kmax']:
        xK = kwargs['Kmax'] - curSSwhole.K
        newUIDs = newUIDs[:xK]
        if xK <= 1:
            errorMsg = 'Cancelled. ' + \
                'Adding 2 or more states would exceed budget of %d comps.' % (
                    kwargs['Kmax'])
            BLogger.pprint(errorMsg)
            BLogger.pprint('')
            return None, dict(errorMsg=errorMsg)
    # Create suff stats for some new states
    xInitSStarget, Info = initSS_BregmanDiv(
        Dslice,
        curModel,
        curLPslice,
        K=xK,
        ktarget=ktarget,
        lapFrac=lapFrac,
        seed=seed + int(1000 * lapFrac),
        logFunc=BLogger.pprint,
        NiterForBregmanKMeans=kwargs['b_NiterForBregmanKMeans'],
        **kwargs)
    # EXIT EARLY: if proposal initialization fails (not enough data).
    if xInitSStarget is None:
        BLogger.pprint('Proposal initialization FAILED. ' + \
                       Info['errorMsg'])
        BLogger.pprint('')
        return None, Info

    # If here, we have a valid set of initial stats.
    xInitSStarget.setUIDs(newUIDs[:xInitSStarget.K])
    if b_doInitCompleteLP:
        # Create valid whole-dataset clustering from hard init
        xInitSSslice, tempInfo = makeExpansionSSFromZ(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            ktarget=ktarget,
            xInitSS=xInitSStarget,
            atomType=Info['atomType'],
            targetZ=Info['targetZ'],
            chosenDataIDs=Info['chosenDataIDs'],
            **kwargs)
        Info.update(tempInfo)

        xSSslice = xInitSSslice
    else:
        xSSslice = xInitSStarget

    if b_debugOutputDir:
        plotCompsFromSS(curModel,
                        xSSslice,
                        os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
                        vocabList=vocabList)

        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        # Create initial proposal
        if b_doInitCompleteLP:
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            # Verify quality
            assert np.allclose(propSS.getCountVec().sum(),
                               curSSwhole.getCountVec().sum())
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "init %d/%d  gainL % .3e  propL % .3e  curL % .3e" %
                (0, b_nRefineSteps, propLdict['Ltotal'] - curLdict['Ltotal'],
                 propLdict['Ltotal'], curLdict['Ltotal']))
            propLdictList.append(propLdict)

        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xInitSStarget.uids):
                if 'targetZ' in Info:
                    if Info['atomType'].count('doc'):
                        initDocUsage_uid = np.sum(Info['targetZ'] == k)
                    else:
                        initDocUsage_uid = 0.0
                        for d in range(Dslice.nDoc):
                            start = Dslice.doc_range[d]
                            stop = Dslice.doc_range[d + 1]
                            initDocUsage_uid += np.any(
                                Info['targetZ'][start:stop] == k)
                else:
                    initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()

    if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
        xInitLPslice = Info['xLPslice']
    else:
        xInitLPslice = None

    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xSSslice)
    BLogger.pprint("   " + vec2str(xInitSStarget.uids))
    pprintCountVec(xSSslice)

    # Log messages to describe the initialization.
    BLogger.pprint(' Running %d refinement iterations (--b_nRefineSteps)' %
                   (b_nRefineSteps))
    prevCountVec = xSSslice.getCountVec()
    didConvEarly = False
    convstep = 100 + b_nRefineSteps
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    for rstep in range(b_nRefineSteps):
        # Update xObsModel
        xObsModel.update_global_params(xSSslice)

        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSslice.uids,
            xInitSS=xSSslice,
            xObsModel=xObsModel,
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            nUpdateSteps=1,
            **kwargs)
        Info.update(refineInfo)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        # On first step, show diagnostics for new states
        if rstep == 0:
            targetPi = refineInfo['emptyPi'] + refineInfo['xPiVec'].sum()
            BLogger.pprint(
                " target prob redistributed by policy %s (--b_method_xPi)" %
                (kwargs['b_method_xPi']))
            msg = " pi[ktarget] before %.4f  after %.4f." % (
                targetPi, refineInfo['emptyPi'])
            BLogger.pprint(msg)
            BLogger.pprint(" pi[new comps]: "  + \
                vec2str(
                    refineInfo['xPiVec'],
                    width=6, minVal=0.0001))
            logLPConvergenceDiagnostics(refineInfo,
                                        rstep=rstep,
                                        b_nRefineSteps=b_nRefineSteps)
            BLogger.pprint("   " + vec2str(xInitSStarget.uids))
        # Show diagnostic counts in each fresh state
        pprintCountVec(xSSslice)
        # Write HTML debug info
        if b_debugOutputDir:
            plotCompsFromSS(curModel,
                            xSSslice,
                            os.path.join(b_debugOutputDir,
                                         'NewComps_Step%d.png' % (rstep + 1)),
                            vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            # Reordering only lifts score by small amount. Not worth it.
            # propSS.reorderComps(np.argsort(-1 * propSS.getCountVec()))
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)

            propSSsubset = xSSslice
            tmpModel = curModelFWD
            tmpModel.obsModel.update_global_params(propSSsubset)
            propLdata_subset = tmpModel.obsModel.calcELBO_Memoized(
                propSSsubset)

            curSSsubset = xSSslice.copy(includeELBOTerms=0)
            while curSSsubset.K > 1:
                curSSsubset.mergeComps(0, 1)
            tmpModel.obsModel.update_global_params(curSSsubset)
            curLdata_subset = tmpModel.obsModel.calcELBO_Memoized(curSSsubset)
            gainLdata_subset = propLdata_subset - curLdata_subset
            msg = \
                "step %d/%d  gainL % .3e  propL % .3e  curL % .3e" % (
                    rstep+1, b_nRefineSteps,
                    propLdict['Ltotal'] - curLdict['Ltotal'],
                    propLdict['Ltotal'],
                    curLdict['Ltotal'])
            msg += "  gainLdata_subset % .3e" % (gainLdata_subset)
            BLogger.pprint(msg)
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])
        # If converged early and did the final refinement step
        if didConvEarly and rstep > convstep:
            break
        # Cleanup by deleting small clusters
        if rstep < b_nRefineSteps - 1:
            if rstep == b_nRefineSteps - 2 or didConvEarly:
                # After all but last step,
                # delete small (but not empty) comps
                minNumAtomsToStay = b_minNumAtomsForNewComp
            else:
                # Always remove empty clusters. They waste our time.
                minNumAtomsToStay = np.minimum(1, b_minNumAtomsForNewComp)
            xSSslice, xInitLPslice = cleanupDeleteSmallClusters(
                xSSslice,
                minNumAtomsToStay,
                xInitLPslice=xInitLPslice,
                pprintCountVec=pprintCountVec)
        # Decide if we have converged early
        if rstep < b_nRefineSteps - 2 and prevCountVec.size == xSSslice.K:
            if np.allclose(xSSslice.getCountVec(), prevCountVec, atol=0.5):
                # Converged. Jump directly to the merge phase!
                didConvEarly = True
                convstep = rstep
        # Cleanup by merging clusters
        if b_cleanupWithMerge and \
                (rstep == b_nRefineSteps - 2 or didConvEarly):
            # Only cleanup on second-to-last pass, or if converged early
            Info['mergestep'] = rstep + 1
            xSSslice, xInitLPslice = cleanupMergeClusters(
                xSSslice,
                curModel,
                obsSSkeys=list(xInitSStarget._Fields._FieldDims.keys()),
                vocabList=vocabList,
                pprintCountVec=pprintCountVec,
                xInitLPslice=xInitLPslice,
                b_debugOutputDir=b_debugOutputDir,
                **kwargs)

        prevCountVec = xSSslice.getCountVec().copy()

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict,
                                 propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(b_debugOutputDir,
                                        'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID, savefilename=savefilename)

    # EXIT EARLY: error if we didn't create enough "big-enough" states.
    nnzCount = np.sum(xSSslice.getCountVec() >= b_minNumAtomsForNewComp)
    if nnzCount < 2:
        Info['errorMsg'] = \
            "Could not create at least two comps" + \
            " with mass >= %.1f (--%s)" % (
                b_minNumAtomsForNewComp, 'b_minNumAtomsForNewComp')
        BLogger.pprint('Proposal build phase FAILED. ' + Info['errorMsg'])
        BLogger.pprint('')  # Blank line
        return None, Info

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal build phase DONE.' + \
        ' Created %d candidate clusters.' % (Info['Kfinal']))
    BLogger.pprint('')  # Blank line
    return xSSslice, Info
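
The mass-conservation assert at the end compares the responsibility mass the target comp held in curLPslice against the total count mass of the expansion stats. A toy check of the same invariant; resp and freshCountVec are fabricated numbers.

import numpy as np

ktarget = 1
resp = np.asarray([[0.9, 0.1],   # toy per-atom responsibilities
                   [0.3, 0.7],
                   [0.5, 0.5]])
origMass = resp[:, ktarget].sum()      # mass assigned to the target comp

# A valid proposal must redistribute exactly that mass over fresh comps.
freshCountVec = np.asarray([0.8, 0.5])
newMass = freshCountVec.sum()
assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)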
Example #8
def makeSummaryForExistingBirthProposal(
        Dslice,
        curModel,
        curLPslice,
        curSSwhole=None,
        targetUID=None,
        ktarget=None,
        LPkwargs=DefaultLPkwargs,
        lapFrac=0,
        batchID=0,
        b_nRefineSteps=3,
        b_debugOutputDir=None,
        b_method_initCoordAscent='fromprevious',
        vocabList=None,
        **kwargs):
    ''' Create summary that reassigns mass from target to a given set of comps.

    The given set of comps is a fixed proposal from a previously-seen batch.

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
    '''
    if targetUID is None:
        targetUID = curSSwhole.k2uid(ktarget)
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Extending previous birth for targetUID %s at lap %.2f batch %d' %
        (targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint('HTML output: ' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(curModel,
                        curSSwhole,
                        os.path.join(b_debugOutputDir, 'OrigComps.png'),
                        vocabList=vocabList,
                        compsToHighlight=[ktarget])

    assert targetUID in curSSwhole.propXSS
    xinitSS = curSSwhole.propXSS[targetUID]
    xK = xinitSS.K
    if xK + curSSwhole.K > kwargs['Kmax']:
        errorMsg = 'Cancelled. ' + \
            'Adding 2 or more states would exceed budget of %d comps.' % (
                kwargs['Kmax'])
        BLogger.pprint(errorMsg)
        BLogger.pprint('')
        return None, dict(errorMsg=errorMsg)

    # Log messages to describe the initialization.
    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xinitSS)
    BLogger.pprint('  Using previous proposal with %d clusters %s.' %
                   (xinitSS.K, '(--b_Kfresh=%d)' % kwargs['b_Kfresh']))
    BLogger.pprint("  Initial uid/counts from previous proposal:")
    BLogger.pprint('   ' + vec2str(xinitSS.uids))
    pprintCountVec(xinitSS)
    BLogger.pprint('  Running %d refinement iterations (--b_nRefineSteps)' %
                   (b_nRefineSteps))

    xSSinitPlusSlice = xinitSS.copy()
    if b_debugOutputDir:
        plotCompsFromSS(curModel,
                        xinitSS,
                        os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
                        vocabList=vocabList)

        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xinitSS.uids):
                initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()
    xInitLPslice = None
    Info = dict()
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    nRefineSteps = np.maximum(1, b_nRefineSteps)
    for rstep in range(nRefineSteps):
        xObsModel.update_global_params(xSSinitPlusSlice)

        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSinitPlusSlice.uids,
            xObsModel=xObsModel,
            xInitSS=xSSinitPlusSlice,  # first time in loop <= xinitSS
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            **kwargs)

        xSSinitPlusSlice += xSSslice
        if rstep >= 1:
            xSSinitPlusSlice -= prevSSslice
        prevSSslice = xSSslice

        Info.update(refineInfo)
        # Show diagnostics for new states
        pprintCountVec(xSSslice)
        logLPConvergenceDiagnostics(refineInfo,
                                    rstep=rstep,
                                    b_nRefineSteps=b_nRefineSteps)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        if b_debugOutputDir:
            plotCompsFromSS(curModel,
                            xSSslice,
                            os.path.join(b_debugOutputDir,
                                         'NewComps_Step%d.png' % (rstep + 1)),
                            vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "step %d/%d  gainL % .3e  propL % .3e  curL % .3e" %
                (rstep + 1, b_nRefineSteps, propLdict['Ltotal'] -
                 curLdict['Ltotal'], propLdict['Ltotal'], curLdict['Ltotal']))
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict,
                                 propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(b_debugOutputDir,
                                        'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID, savefilename=savefilename)

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal extension DONE. %d candidate clusters.' %
                   (Info['Kfinal']))
    BLogger.pprint('')
    return xSSslice, Info
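
Inside the refinement loop above, xSSinitPlusSlice always equals the initial proposal stats plus only the latest slice: each iteration adds the new slice and, from the second step on, subtracts the one added previously. A scalar sketch of that rolling update (real SuffStatBags support += and -= the same way):

# Rolling update: running = init + latest slice, never accumulating old slices.
xinit = 10.0                # stands in for xinitSS
running = xinit             # stands in for xSSinitPlusSlice
slices = [4.0, 6.0, 5.0]    # stands in for per-step xSSslice values
prev = None
for rstep, cur in enumerate(slices):
    running += cur
    if rstep >= 1:
        running -= prev     # drop the slice added last iteration
    prev = cur
    assert running == xinit + cur
print(running)              # 15.0 == xinit + final slice
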
def makeSummariesForManyBirthProposals(Dslice=None,
                                       curModel=None,
                                       curLPslice=None,
                                       curSSwhole=None,
                                       curSSslice=None,
                                       LPkwargs=None,
                                       newUIDs=list(),
                                       b_targetUIDs=None,
                                       xSSProposalsByUID=None,
                                       MovePlans=dict(),
                                       MoveRecordsByUID=dict(),
                                       taskoutpath='/tmp/',
                                       lapFrac=0.0,
                                       batchID=0,
                                       batchPos=0,
                                       nBatch=0,
                                       **BArgs):
    ''' Make a birth proposal summary for each UID in b_targetUIDs.

    Args
    ----
    BArgs : dict of all kwarg options for birth moves

    Returns
    -------
    xSSProposalsByUID : dict
    MovePlans : dict
        Tracks aggregate performance across all birth proposals.
    MoveRecordsByUID : dict
        each key is a uid. Tracks performance for each uid.
    '''
    if b_targetUIDs is None:
        b_targetUIDs = MovePlans['b_targetUIDs']
    if len(b_targetUIDs) > 0:
        BLogger.pprint('CREATING birth proposals at lap %.2f batch %d' %
                       (lapFrac, batchID))
    if xSSProposalsByUID is None:
        xSSProposalsByUID = dict()
    failedUIDs = list()
    # Loop thru copy of the target comp UID list
    # So that we can remove elements from it within the loop
    for ii, targetUID in enumerate(b_targetUIDs):

        if targetUID in xSSProposalsByUID:
            raise ValueError("Already have a proposal for this UID")

        Kfresh = BArgs['b_Kfresh']
        newUIDs_ii = newUIDs[(ii * Kfresh):((ii + 1) * Kfresh)]
        if len(newUIDs_ii) < 2:
            raise ValueError("Cannot make proposal with less than 2 new UIDs")
        xSSslice, Info = makeSummaryForBirthProposal_HTMLWrapper(
            Dslice,
            curModel,
            curLPslice,
            curSSwhole=curSSwhole,
            targetUID=targetUID,
            newUIDs=newUIDs_ii,
            LPkwargs=LPkwargs,
            lapFrac=lapFrac,
            batchID=batchID,
            **BArgs)
        if xSSslice is not None:
            # Proposal successful, with at least 2 non-empty clusters.
            # Move on to the evaluation stage!
            xSSProposalsByUID[targetUID] = xSSslice
        else:
            # Failure. Expansion did not create good proposal.
            failedUIDs.append(targetUID)
            MovePlans['b_nTrial'] += 1
            MovePlans['b_nFailedProp'] += 1
            if targetUID not in MoveRecordsByUID:
                MoveRecordsByUID[targetUID] = defaultdict(int)
            uidRec = MoveRecordsByUID[targetUID]
            ktarget = curSSwhole.uid2k(targetUID)
            uidRec['b_nTrial'] += 1
            uidRec['b_nFail'] += 1
            uidRec['b_nFailRecent'] += 1
            uidRec['b_nSuccessRecent'] = 0
            uidRec['b_latestLap'] = lapFrac
            uidRec['b_latestCount'] = curSSwhole.getCountVec()[ktarget]
            # Update batch-specific records for this uid
            uidRec_b = uidRec['byBatch'][uidRec['b_proposalBatchID']]
            uidRec_b['nFail'] += 1

    for failUID in failedUIDs:
        b_targetUIDs.remove(failUID)
    MovePlans['b_targetUIDs'] = b_targetUIDs
    return xSSProposalsByUID, MovePlans, MoveRecordsByUID
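
Each proposal above consumes its own contiguous block of b_Kfresh fresh UIDs from the shared newUIDs pool, so concurrent proposals never collide. A quick sketch of that slicing:

b_Kfresh = 3
newUIDs = list(range(100, 112))   # 12 fresh ids, enough for 4 proposals
b_targetUIDs = [7, 9, 11]
for ii, targetUID in enumerate(b_targetUIDs):
    newUIDs_ii = newUIDs[(ii * b_Kfresh):((ii + 1) * b_Kfresh)]
    print(targetUID, newUIDs_ii)
# 7 [100, 101, 102]
# 9 [103, 104, 105]
# 11 [106, 107, 108]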