Example #1
def makeSummaryForBirthProposal(Dslice,
                                curModel,
                                curLPslice,
                                curSSwhole=None,
                                b_creationProposalName='bregmankmeans',
                                targetUID=None,
                                ktarget=None,
                                newUIDs=None,
                                LPkwargs=DefaultLPkwargs,
                                lapFrac=0,
                                batchID=0,
                                seed=0,
                                b_nRefineSteps=3,
                                b_debugOutputDir=None,
                                b_minNumAtomsForNewComp=None,
                                b_doInitCompleteLP=1,
                                b_cleanupWithMerge=1,
                                b_method_initCoordAscent='fromprevious',
                                vocabList=None,
                                **kwargs):
    ''' Create summary that reassigns mass from target to Kfresh new comps.

    TODO: support options other than bregman?

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
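
    A hypothetical usage sketch appears after this function body.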
    '''
    # Parse input to decide which cluster to target
    # * targetUID is the unique ID of this cluster
    # * ktarget is its position in the current cluster ordering
    if targetUID is None:
        targetUID = curSSwhole.k2uid(ktarget)
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Creating proposal for targetUID %s at lap %.2f batchID %d' %
        (targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint('HTML output: ' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(curModel,
                        curSSwhole,
                        os.path.join(b_debugOutputDir, 'OrigComps.png'),
                        vocabList=vocabList,
                        compsToHighlight=[ktarget])

    # Determine exactly how many new states we can make...
    xK = len(newUIDs)
    if xK + curSSwhole.K > kwargs['Kmax']:
        xK = kwargs['Kmax'] - curSSwhole.K
        newUIDs = newUIDs[:xK]
        if xK <= 1:
            errorMsg = 'Cancelled. ' + \
                'Adding 2 or more states would exceed budget of %d comps.' % (
                    kwargs['Kmax'])
            BLogger.pprint(errorMsg)
            BLogger.pprint('')
            return None, dict(errorMsg=errorMsg)
    # Create suff stats for some new states
    xInitSStarget, Info = initSS_BregmanDiv(
        Dslice,
        curModel,
        curLPslice,
        K=xK,
        ktarget=ktarget,
        lapFrac=lapFrac,
        seed=seed + int(1000 * lapFrac),
        logFunc=BLogger.pprint,
        NiterForBregmanKMeans=kwargs['b_NiterForBregmanKMeans'],
        **kwargs)
    # EXIT EARLY: if proposal initialization fails (not enough data).
    if xInitSStarget is None:
        BLogger.pprint('Proposal initialization FAILED. ' + \
                       Info['errorMsg'])
        BLogger.pprint('')
        return None, Info

    # If here, we have a valid set of initial stats.
    xInitSStarget.setUIDs(newUIDs[:xInitSStarget.K])
    if b_doInitCompleteLP:
        # Create valid whole-dataset clustering from hard init
        xInitSSslice, tempInfo = makeExpansionSSFromZ(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            ktarget=ktarget,
            xInitSS=xInitSStarget,
            atomType=Info['atomType'],
            targetZ=Info['targetZ'],
            chosenDataIDs=Info['chosenDataIDs'],
            **kwargs)
        Info.update(tempInfo)

        xSSslice = xInitSSslice
    else:
        xSSslice = xInitSStarget

    if b_debugOutputDir:
        plotCompsFromSS(curModel,
                        xSSslice,
                        os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
                        vocabList=vocabList)

        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        # Create initial proposal
        if b_doInitCompleteLP:
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            # Verify quality
            assert np.allclose(propSS.getCountVec().sum(),
                               curSSwhole.getCountVec().sum())
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "init %d/%d  gainL % .3e  propL % .3e  curL % .3e" %
                (0, b_nRefineSteps, propLdict['Ltotal'] - curLdict['Ltotal'],
                 propLdict['Ltotal'], curLdict['Ltotal']))
            propLdictList.append(propLdict)

        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xInitSStarget.uids):
                if 'targetZ' in Info:
                    if Info['atomType'].count('doc'):
                        initDocUsage_uid = np.sum(Info['targetZ'] == k)
                    else:
                        initDocUsage_uid = 0.0
                        for d in range(Dslice.nDoc):
                            start = Dslice.doc_range[d]
                            stop = Dslice.doc_range[d + 1]
                            initDocUsage_uid += np.any(
                                Info['targetZ'][start:stop] == k)
                else:
                    initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()

    if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
        xInitLPslice = Info['xLPslice']
    else:
        xInitLPslice = None

    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xSSslice)
    BLogger.pprint("   " + vec2str(xInitSStarget.uids))
    pprintCountVec(xSSslice)

    # Log messages to describe the initialization.
    BLogger.pprint(' Running %d refinement iterations (--b_nRefineSteps)' %
                   (b_nRefineSteps))
    prevCountVec = xSSslice.getCountVec()
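    # Early-convergence bookkeeping: convstep starts beyond any reachable
    # rstep, and is lowered to the step where counts stop changing.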
    didConvEarly = False
    convstep = 100 + b_nRefineSteps
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    for rstep in range(b_nRefineSteps):
        # Update xObsModel
        xObsModel.update_global_params(xSSslice)

        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSslice.uids,
            xInitSS=xSSslice,
            xObsModel=xObsModel,
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            nUpdateSteps=1,
            **kwargs)
        Info.update(refineInfo)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        # On first step, show diagnostics for new states
        if rstep == 0:
            targetPi = refineInfo['emptyPi'] + refineInfo['xPiVec'].sum()
            BLogger.pprint(
                " target prob redistributed by policy %s (--b_method_xPi)" %
                (kwargs['b_method_xPi']))
            msg = " pi[ktarget] before %.4f  after %.4f." % (
                targetPi, refineInfo['emptyPi'])
            BLogger.pprint(msg)
            BLogger.pprint(" pi[new comps]: "  + \
                vec2str(
                    refineInfo['xPiVec'],
                    width=6, minVal=0.0001))
            logLPConvergenceDiagnostics(refineInfo,
                                        rstep=rstep,
                                        b_nRefineSteps=b_nRefineSteps)
            BLogger.pprint("   " + vec2str(xInitSStarget.uids))
        # Show diagnostic counts in each fresh state
        pprintCountVec(xSSslice)
        # Write HTML debug info
        if b_debugOutputDir:
            plotCompsFromSS(curModel,
                            xSSslice,
                            os.path.join(b_debugOutputDir,
                                         'NewComps_Step%d.png' % (rstep + 1)),
                            vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            # Reordering only lifts score by small amount. Not worth it.
            # propSS.reorderComps(np.argsort(-1 * propSS.getCountVec()))
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)

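            # Compare the obs-model ELBO of the fresh comps against the same
            # mass collapsed into a single comp, to isolate how much the
            # split itself improves the data term.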
            propSSsubset = xSSslice
            tmpModel = curModelFWD
            tmpModel.obsModel.update_global_params(propSSsubset)
            propLdata_subset = tmpModel.obsModel.calcELBO_Memoized(
                propSSsubset)

            curSSsubset = xSSslice.copy(includeELBOTerms=0)
            while curSSsubset.K > 1:
                curSSsubset.mergeComps(0, 1)
            tmpModel.obsModel.update_global_params(curSSsubset)
            curLdata_subset = tmpModel.obsModel.calcELBO_Memoized(curSSsubset)
            gainLdata_subset = propLdata_subset - curLdata_subset
            msg = \
                "step %d/%d  gainL % .3e  propL % .3e  curL % .3e" % (
                    rstep+1, b_nRefineSteps,
                    propLdict['Ltotal'] - curLdict['Ltotal'],
                    propLdict['Ltotal'],
                    curLdict['Ltotal'])
            msg += "  gainLdata_subset % .3e" % (gainLdata_subset)
            BLogger.pprint(msg)
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])
        # Stop once we converged early and completed one final refinement pass
        if didConvEarly and rstep > convstep:
            break
        # Cleanup by deleting small clusters
        if rstep < b_nRefineSteps - 1:
            if rstep == b_nRefineSteps - 2 or didConvEarly:
                # On the second-to-last step (or after early convergence),
                # delete small (but not empty) comps
                minNumAtomsToStay = b_minNumAtomsForNewComp
            else:
                # Always remove empty clusters. They waste our time.
                minNumAtomsToStay = np.minimum(1, b_minNumAtomsForNewComp)
            xSSslice, xInitLPslice = cleanupDeleteSmallClusters(
                xSSslice,
                minNumAtomsToStay,
                xInitLPslice=xInitLPslice,
                pprintCountVec=pprintCountVec)
        # Decide if we have converged early
        if rstep < b_nRefineSteps - 2 and prevCountVec.size == xSSslice.K:
            if np.allclose(xSSslice.getCountVec(), prevCountVec, atol=0.5):
                # Converged. Jump directly to the merge phase!
                didConvEarly = True
                convstep = rstep
        # Cleanup by merging clusters
        if b_cleanupWithMerge and \
                (rstep == b_nRefineSteps - 2 or didConvEarly):
            # Only cleanup on second-to-last pass, or if converged early
            Info['mergestep'] = rstep + 1
            xSSslice, xInitLPslice = cleanupMergeClusters(
                xSSslice,
                curModel,
                obsSSkeys=xInitSStarget._Fields._FieldDims.keys(),
                vocabList=vocabList,
                pprintCountVec=pprintCountVec,
                xInitLPslice=xInitLPslice,
                b_debugOutputDir=b_debugOutputDir,
                **kwargs)

        prevCountVec = xSSslice.getCountVec().copy()

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict,
                                 propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(b_debugOutputDir,
                                        'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID, savefilename=savefilename)

    # EXIT EARLY: error if we didn't create enough "big-enough" states.
    nnzCount = np.sum(xSSslice.getCountVec() >= b_minNumAtomsForNewComp)
    if nnzCount < 2:
        Info['errorMsg'] = \
            "Could not create at least two comps" + \
            " with mass >= %.1f (--%s)" % (
                b_minNumAtomsForNewComp, 'b_minNumAtomsForNewComp')
        BLogger.pprint('Proposal build phase FAILED. ' + Info['errorMsg'])
        BLogger.pprint('')  # Blank line
        return None, Info

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal build phase DONE.' + \
        ' Created %d candidate clusters.' % (Info['Kfinal']))
    BLogger.pprint('')  # Blank line
    return xSSslice, Info
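
A hypothetical usage sketch for the function above (not from the source; it
assumes hmodel, SS, Dslice, and LPslice come from an in-progress bnpy
training run, and every keyword value shown is an illustrative guess):

# Hypothetical driver code; all names and values below are illustrative.
xSSslice, Info = makeSummaryForBirthProposal(
    Dslice, hmodel, LPslice,
    curSSwhole=SS,
    targetUID=SS.uids[0],        # cluster whose mass we try to split
    newUIDs=[1001, 1002, 1003],  # reserved UIDs for the fresh comps
    b_nRefineSteps=3,
    b_minNumAtomsForNewComp=10.0,
    Kmax=50,                     # read from **kwargs inside
    b_NiterForBregmanKMeans=10,  # read from **kwargs inside
    b_method_xPi='...')          # required in kwargs; valid values not shown
if xSSslice is None:
    print(Info['errorMsg'])      # proposal abandoned; keep the current model
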
Example #2
def cleanupMergeClusters(xSSslice,
                         curModel,
                         xInitLPslice=None,
                         obsSSkeys=None,
                         vocabList=None,
                         b_cleanupMaxNumMergeIters=3,
                         b_cleanupMaxNumAcceptPerIter=1,
                         b_mergeLam=None,
                         b_debugOutputDir=None,
                         pprintCountVec=None,
                         **kwargs):
    ''' Merge all possible pairs of clusters that improve the Ldata objective.

    Returns
    -------
    xSSslice : SuffStatBag
        May have fewer components than the provided xSSslice.
    xInitLPslice : dict or None
        Kept aligned with the merged components, if provided.
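
    A self-contained toy example of the pair-selection step appears after
    this function body.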
    '''
    xSSslice.removeELBOandMergeTerms()
    xSSslice.removeSelectionTerms()
    # Discard all fields unrelated to observation model
    reqFields = set()
    for key in obsSSkeys:
        reqFields.add(key)
    for key in xSSslice._Fields._FieldDims.keys():
        if key not in reqFields:
            xSSslice.removeField(key)

    # For merges, we can crank up the value of the topic-word prior
    # hyperparameter, so that only big differences in word counts across
    # many terms drive accepted merges.
    tmpModel = curModel.copy()
    if b_mergeLam is not None:
        tmpModel.obsModel.Prior.lam[:] = b_mergeLam

    mergeID = 0
    for trial in range(b_cleanupMaxNumMergeIters):
        tmpModel.obsModel.update_global_params(xSSslice)
        GainLdata = tmpModel.obsModel.calcHardMergeGap_AllPairs(xSSslice)

        triuIDs = np.triu_indices(xSSslice.K, 1)
        posLocs = np.flatnonzero(GainLdata[triuIDs] > 0)
        if posLocs.size == 0:
            # No merges to accept. Stop!
            break

        # Rank the positive pairs from largest to smallest
        sortIDs = np.argsort(-1 * GainLdata[triuIDs][posLocs])
        posLocs = posLocs[sortIDs]

        usedUIDs = set()
        nAccept = 0
        uidpairsToAccept = list()
        origidsToAccept = list()
        for loc in posLocs:
            kA = triuIDs[0][loc]
            kB = triuIDs[1][loc]
            uidA = xSSslice.uids[kA]
            uidB = xSSslice.uids[kB]
            if uidA in usedUIDs or uidB in usedUIDs:
                continue
            usedUIDs.add(uidA)
            usedUIDs.add(uidB)
            uidpairsToAccept.append((uidA, uidB))
            origidsToAccept.append((kA, kB))
            nAccept += 1
            if nAccept >= b_cleanupMaxNumAcceptPerIter:
                break

        for posID, (uidA, uidB) in enumerate(uidpairsToAccept):
            mergeID += 1
            kA, kB = origidsToAccept[posID]
            xSSslice.mergeComps(uidA=uidA, uidB=uidB)

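            # Fold kB's doc-topic mass into kA, and tombstone column kB
            # with -1 so it can be purged after this pass.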
            if xInitLPslice:
                xInitLPslice['DocTopicCount'][:, kA] += \
                    xInitLPslice['DocTopicCount'][:, kB]
                xInitLPslice['DocTopicCount'][:, kB] = -1

            if b_debugOutputDir:
                savefilename = os.path.join(b_debugOutputDir,
                                            'MergeComps_%d.png' % (mergeID))
                # Show side-by-side topics
                bnpy.viz.PlotComps.plotCompsFromHModel(
                    tmpModel,
                    compListToPlot=[kA, kB],
                    vocabList=vocabList,
                    xlabels=[str(uidA), str(uidB)],
                )
                bnpy.viz.PlotUtil.pylab.savefig(savefilename, pad_inches=0)

        if len(uidpairsToAccept) > 0:
            pprintCountVec(xSSslice, uidpairsToAccept=uidpairsToAccept)

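        # Purge tombstoned (-1) columns so DocTopicCount stays aligned with
        # the merged components.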
        if xInitLPslice:
            badIDs = np.flatnonzero(xInitLPslice['DocTopicCount'][0, :] < 0)
            for kk in reversed(badIDs):
                xInitLPslice['DocTopicCount'] = np.delete(
                    xInitLPslice['DocTopicCount'], kk, axis=1)

    if mergeID > 0 and b_debugOutputDir:
        tmpModel.obsModel.update_global_params(xSSslice)
        outpath = os.path.join(b_debugOutputDir, 'NewComps_AfterMerge.png')
        plotCompsFromSS(
            tmpModel,
            xSSslice,
            outpath,
            vocabList=vocabList,
        )

    if xInitLPslice:
        assert xInitLPslice['DocTopicCount'].min() >= -1e-6
        assert xInitLPslice['DocTopicCount'].shape[1] == xSSslice.K
    return xSSslice, xInitLPslice
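
The greedy pair-selection logic above can be exercised in isolation. Below is
a self-contained toy example (pure numpy; the gain matrix is made up) showing
how positive-gain pairs are ranked and accepted so that no cluster joins more
than one merge per pass:

import numpy as np

# Made-up stand-in for calcHardMergeGap_AllPairs: entry [i, j] (i < j) is
# the objective gain from merging clusters i and j.
K = 4
GainLdata = np.array([
    [0.0, 2.0, -1.0, 0.5],
    [0.0, 0.0,  3.0, -2.0],
    [0.0, 0.0,  0.0,  1.0],
    [0.0, 0.0,  0.0,  0.0]])
uids = [101, 102, 103, 104]

triuIDs = np.triu_indices(K, 1)
posLocs = np.flatnonzero(GainLdata[triuIDs] > 0)
# Rank the positive-gain pairs from largest to smallest
posLocs = posLocs[np.argsort(-1 * GainLdata[triuIDs][posLocs])]

usedUIDs = set()
for loc in posLocs:
    uidA = uids[triuIDs[0][loc]]
    uidB = uids[triuIDs[1][loc]]
    if uidA in usedUIDs or uidB in usedUIDs:
        continue  # each cluster may join at most one merge per pass
    usedUIDs.update((uidA, uidB))
    print('accept merge of', uidA, uidB)
# Prints (102, 103) first (gain 3.0), then (101, 104) (gain 0.5); the pairs
# (101, 102) and (103, 104) are skipped because a member is already used.
# The real function additionally caps accepts at b_cleanupMaxNumAcceptPerIter.
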
Example #3
def makeSummaryForExistingBirthProposal(
        Dslice,
        curModel,
        curLPslice,
        curSSwhole=None,
        targetUID=None,
        ktarget=None,
        LPkwargs=DefaultLPkwargs,
        lapFrac=0,
        batchID=0,
        b_nRefineSteps=3,
        b_debugOutputDir=None,
        b_method_initCoordAscent='fromprevious',
        vocabList=None,
        **kwargs):
    ''' Create summary that reassigns mass from target to a given set of comps.

    The given set of comps is a fixed proposal from a previously-seen batch.

    Returns
    -------
    xSSslice : SuffStatBag
        Contains exact summaries for reassignment of target mass.
        * Total mass is equal to mass assigned to ktarget in curLPslice
        * Number of components is Kfresh
    Info : dict
        Contains info for detailed debugging of construction process.
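
    A hypothetical usage sketch appears after this function body.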
    '''
    if targetUID is None:
        targetUID = curSSwhole.uids[ktarget]
    if ktarget is None:
        ktarget = curSSwhole.uid2k(targetUID)
    # START log for this birth proposal
    BLogger.pprint(
        'Extending previous birth for targetUID %s at lap %.2f batch %d' %
        (targetUID, lapFrac, batchID))
    # Grab vocabList, if available.
    if hasattr(Dslice, 'vocabList') and Dslice.vocabList is not None:
        vocabList = Dslice.vocabList
    # Parse input to decide where to save HTML output
    if b_debugOutputDir == 'None':
        b_debugOutputDir = None
    if b_debugOutputDir:
        BLogger.pprint('HTML output: ' + b_debugOutputDir)
        # Create snapshot of current model comps
        plotCompsFromSS(curModel,
                        curSSwhole,
                        os.path.join(b_debugOutputDir, 'OrigComps.png'),
                        vocabList=vocabList,
                        compsToHighlight=[ktarget])

    assert targetUID in curSSwhole.propXSS
    xinitSS = curSSwhole.propXSS[targetUID]
    xK = xinitSS.K
    if xK + curSSwhole.K > kwargs['Kmax']:
        errorMsg = 'Cancelled. ' + \
            'Adding %d more states would exceed budget of %d comps.' % (
                xK, kwargs['Kmax'])
        BLogger.pprint(errorMsg)
        BLogger.pprint('')
        return None, dict(errorMsg=errorMsg)

    # Log messages to describe the initialization.
    # Make a function to pretty-print counts as we refine the initialization
    pprintCountVec = BLogger.makeFunctionToPrettyPrintCounts(xinitSS)
    BLogger.pprint('  Using previous proposal with %d clusters %s.' %
                   (xinitSS.K, '(--b_Kfresh=%d)' % kwargs['b_Kfresh']))
    BLogger.pprint("  Initial uid/counts from previous proposal:")
    BLogger.pprint('   ' + vec2str(xinitSS.uids))
    pprintCountVec(xinitSS)
    BLogger.pprint('  Running %d refinement iterations (--b_nRefineSteps)' %
                   (b_nRefineSteps))

    xSSinitPlusSlice = xinitSS.copy()
    if b_debugOutputDir:
        plotCompsFromSS(curModel,
                        xinitSS,
                        os.path.join(b_debugOutputDir, 'NewComps_Init.png'),
                        vocabList=vocabList)

        # Determine current model objective score
        curModelFWD = curModel.copy()
        curModelFWD.update_global_params(SS=curSSwhole)
        curLdict = curModelFWD.calc_evidence(SS=curSSwhole, todict=1)
        # Track proposal ELBOs as refinement improves things
        propLdictList = list()
        docUsageByUID = dict()
        if curModel.getAllocModelName().count('HDP'):
            for k, uid in enumerate(xinitSS.uids):
                initDocUsage_uid = 0.0
                docUsageByUID[uid] = [initDocUsage_uid]

    # Create initial observation model
    xObsModel = curModel.obsModel.copy()
    xInitLPslice = None
    Info = dict()
    # Run several refinement steps.
    # Each step does a restricted local step to improve
    # the proposed cluster assignments.
    nRefineSteps = np.maximum(1, b_nRefineSteps)
    for rstep in range(nRefineSteps):
        xObsModel.update_global_params(xSSinitPlusSlice)

        # Restricted local step!
        # * xInitSS : specifies obs-model stats used for initialization
        xSSslice, refineInfo = summarizeRestrictedLocalStep(
            Dslice=Dslice,
            curModel=curModel,
            curLPslice=curLPslice,
            curSSwhole=curSSwhole,
            ktarget=ktarget,
            xUIDs=xSSinitPlusSlice.uids,
            xObsModel=xObsModel,
            xInitSS=xSSinitPlusSlice,  # equals xinitSS on the first pass
            xInitLPslice=xInitLPslice,
            LPkwargs=LPkwargs,
            **kwargs)

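        # Keep a running total equal to xinitSS plus the freshest slice:
        # add this step's slice stats, then subtract the previous step's so
        # nothing is double-counted.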
        xSSinitPlusSlice += xSSslice
        if rstep >= 1:
            xSSinitPlusSlice -= prevSSslice
        prevSSslice = xSSslice

        Info.update(refineInfo)
        # Show diagnostics for new states
        pprintCountVec(xSSslice)
        logLPConvergenceDiagnostics(refineInfo,
                                    rstep=rstep,
                                    b_nRefineSteps=b_nRefineSteps)
        # Get most recent xLPslice for initialization
        if b_method_initCoordAscent == 'fromprevious' and 'xLPslice' in Info:
            xInitLPslice = Info['xLPslice']
        if b_debugOutputDir:
            plotCompsFromSS(curModel,
                            xSSslice,
                            os.path.join(b_debugOutputDir,
                                         'NewComps_Step%d.png' % (rstep + 1)),
                            vocabList=vocabList)
            propSS = curSSwhole.copy()
            propSS.transferMassFromExistingToExpansion(uid=targetUID,
                                                       xSS=xSSslice)
            propModel = curModel.copy()
            propModel.update_global_params(propSS)
            propLdict = propModel.calc_evidence(SS=propSS, todict=1)
            BLogger.pprint(
                "step %d/%d  gainL % .3e  propL % .3e  curL % .3e" %
                (rstep + 1, b_nRefineSteps, propLdict['Ltotal'] -
                 curLdict['Ltotal'], propLdict['Ltotal'], curLdict['Ltotal']))
            propLdictList.append(propLdict)
            if curModel.getAllocModelName().count('HDP'):
                docUsageVec = xSSslice.getSelectionTerm('DocUsageCount')
                for k, uid in enumerate(xSSslice.uids):
                    docUsageByUID[uid].append(docUsageVec[k])

    Info['Kfinal'] = xSSslice.K
    if b_debugOutputDir:
        savefilename = os.path.join(b_debugOutputDir, 'ProposalTrace_ELBO.png')
        plotELBOtermsForProposal(curLdict,
                                 propLdictList,
                                 savefilename=savefilename)
        if curModel.getAllocModelName().count('HDP'):
            savefilename = os.path.join(b_debugOutputDir,
                                        'ProposalTrace_DocUsage.png')
            plotDocUsageForProposal(docUsageByUID, savefilename=savefilename)

    # If here, we have a valid proposal.
    # Need to verify mass conservation
    if hasattr(Dslice, 'word_count') and \
            curModel.obsModel.DataAtomType.count('word') and \
            curModel.getObsModelName().count('Mult'):
        origMass = np.inner(Dslice.word_count, curLPslice['resp'][:, ktarget])
    else:
        if 'resp' in curLPslice:
            origMass = curLPslice['resp'][:, ktarget].sum()
        else:
            origMass = curLPslice['spR'][:, ktarget].sum()
    newMass = xSSslice.getCountVec().sum()
    assert np.allclose(newMass, origMass, atol=1e-6, rtol=0)
    BLogger.pprint('Proposal extension DONE. %d candidate clusters.' %
                   (Info['Kfinal']))
    BLogger.pprint('')
    return xSSslice, Info
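
A hypothetical usage sketch for the function above (not from the source; it
assumes a bnpy run in which an earlier batch already stored an expansion
proposal for the target cluster under curSSwhole.propXSS, and all keyword
values are illustrative):

# Hypothetical driver: reuse a proposal remembered from an earlier batch.
targetUID = 7  # illustrative; must already be a key of SS.propXSS
if hasattr(SS, 'propXSS') and targetUID in SS.propXSS:
    xSSslice, Info = makeSummaryForExistingBirthProposal(
        Dslice, hmodel, LPslice,
        curSSwhole=SS,
        targetUID=targetUID,
        b_nRefineSteps=3,
        Kmax=50,       # read from **kwargs inside
        b_Kfresh=10)   # read from **kwargs inside (log message only)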