示例#1
0
def plotCovMatFromHModel(hmodel,
                         compListToPlot=None,
                         compsToHighlight=None,
                         proba_thr=0.001,
                         ax_handle=None,
                         **kwargs):
    ''' Plot square image of covariance matrix for each component.

    Panels are laid out on a grid with 2 rows and ceil(K / 2) columns,
    where K = hmodel.obsModel.K. Grid cells not used by a component are
    turned off.

    Parameters
    -------
    hmodel : bnpy HModel object
        Must provide obsModel (attributes K, D and method
        get_covar_mat_for_comp) and allocModel (attribute K and method
        get_active_comp_probs).
    compListToPlot : array-like of integer IDs of components within hmodel
        Defaults to all components 0, ..., K-1.
    compsToHighlight : int or array-like
        integer IDs to highlight. Highlighted components always have their
        covariance drawn, and their xlabel is replaced by '***'.
    proba_thr : float
        Minimum weight needed for a component's covariance to be drawn.
        Non-highlighted components with weight below proba_thr are drawn
        as a placeholder image (from getEmptyCompSigmaImage) instead.
    ax_handle : figure-like or None
        Despite the name, this is used as a figure: its `.number` is passed
        as `num` to pylab.subplots so the grid is drawn in that figure.
        If None, a new figure is created.
    **kwargs
        Ignored.
    '''
    nRow = 2
    nCol = int(np.ceil(hmodel.obsModel.K / 2.0))
    if ax_handle is None:
        pylab.subplots(nrows=nRow,
                       ncols=nCol,
                       figsize=(nCol * 2, nRow * 2))
    else:
        pylab.subplots(nrows=nRow, ncols=nCol, num=ax_handle.number)

    if compsToHighlight is not None:
        compsToHighlight = np.asarray(compsToHighlight)
        if compsToHighlight.ndim == 0:
            compsToHighlight = np.asarray([compsToHighlight])
    else:
        compsToHighlight = list()
    if compListToPlot is None:
        compListToPlot = np.arange(0, hmodel.obsModel.K)

    # Component weights are only meaningful when alloc and obs models agree
    # on the number of components; otherwise every component gets weight 1.
    if hmodel.allocModel.K == hmodel.obsModel.K:
        w = hmodel.allocModel.get_active_comp_probs()
    else:
        w = np.ones(hmodel.obsModel.K)

    # Number of grid cells used so far. Tracked explicitly so an empty
    # compListToPlot does not leave the loop variable unbound.
    nPlotted = 0
    for plotID, kk in enumerate(compListToPlot):
        if w[kk] < proba_thr and kk not in compsToHighlight:
            Sigma = getEmptyCompSigmaImage(hmodel.obsModel.D)
            clim = [0, 1]
        else:
            Sigma = hmodel.obsModel.get_covar_mat_for_comp(kk)
            clim = [-.25, 1]
        pylab.subplot(nRow, nCol, plotID + 1)
        pylab.imshow(Sigma, interpolation='nearest', cmap='hot', clim=clim)
        pylab.xticks([])
        pylab.yticks([])
        if kk in compsToHighlight:
            pylab.xlabel('***')
        else:
            pylab.xlabel('%.2f' % (w[kk]))
        nPlotted = plotID + 1

    # Hide the remaining, unused grid cells.
    for emptyID in range(nPlotted, nRow * nCol):
        aH = pylab.subplot(nRow, nCol, emptyID + 1)
        aH.axis('off')
示例#2
0
def plotManyPanelsByPVar(jpathPattern='/tmp/',
                         pvar=None,
                         pvals=None,
                         W=5,
                         H=4,
                         savefilename=None,
                         doShowNow=False,
                         **kwargs):
    ''' Draw one row of line-plot panels, one panel per value of `pvar`.

    When pvar is None, a single panel is drawn for jpathPattern. Otherwise,
    the jobs matching jpathPattern are split by the values of pvar (kept
    only if they appear in the pattern's parameter map). Each panel is drawn
    by plotMultipleLinesByLVar, and all panels share x and y axes.

    Parameters
    ----------
    jpathPattern : str
        Job path pattern. Its parent directory is used as prefix path.
    pvar : str or None
        Name of the variable that distinguishes panels.
    pvals : list or None
        Requested values of pvar. None means all known values.
    W, H : float
        Width/height in inches of each panel.
    savefilename : str or None
        If given, the figure is saved here (after a non-blocking show).
    doShowNow : bool
        If True and savefilename is None, show the figure (blocking).
    **kwargs
        Forwarded to makeListOfJPatternsWithSpecificVals and to
        plotMultipleLinesByLVar. A 'loc' entry is reset to None after the
        first panel so only that panel gets a legend.

    Returns
    -------
    dict with keys 'nrows' and 'ncols' describing the panel grid.
    '''
    if pvar is not None:
        parentDir = os.path.sep.join(jpathPattern.split(os.path.sep)[:-1])
        ppMap = makePPListMapFromJPattern(jpathPattern)
        if pvals is None:
            pvals = ppMap[pvar]
        else:
            pvals = [val for val in pvals if val in ppMap[pvar]]
        jobPatterns = makeListOfJPatternsWithSpecificVals(
            ppMap,
            prefixfilepath=parentDir,
            key=pvar,
            vals=pvals,
            **kwargs)
    else:
        jobPatterns = [jpathPattern]
        pvals = [None]

    nrows, ncols = 1, len(pvals)
    pylab.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * W, nrows * H))

    sharedAx = None
    for panelIdx, jobPattern in enumerate(jobPatterns):
        sharedAx = pylab.subplot(nrows, ncols, panelIdx + 1,
                                 sharey=sharedAx, sharex=sharedAx)
        # Legend is kept only on the very first panel
        if panelIdx and 'loc' in kwargs:
            kwargs['loc'] = None
        kwargs['doShowNow'] = False
        plotMultipleLinesByLVar(jobPattern, **kwargs)
        if pvar is not None:
            pylab.title('%s=%s' % (pvar, pvals[panelIdx]))

    pylab.subplots_adjust(bottom=0.15, wspace=0.5)

    if savefilename is not None or doShowNow:
        try:
            # Saving needs a non-blocking show; plain display blocks.
            pylab.show(block=savefilename is None)
        except TypeError:
            pass  # show() may reject `block` (e.g. in an IPython notebook)
        if savefilename is not None:
            pylab.savefig(savefilename, bbox_inches='tight', pad_inches=0)

    return dict(nrows=nrows, ncols=ncols)
示例#3
0
                                              lapFrac=aiter)
            if len(MInfo['m_UIDPairs']) > 0:
                for (uidA, uidB) in MInfo['m_UIDPairs']:
                    combinedSS.mergeComps(uidA=uidA, uidB=uidB)
                    trainSS.mergeComps(uidA=uidA, uidB=uidB)
                combinedModel.update_global_params(combinedSS)

    print()
    print("Plotting final combined model!")
    print("Each plot shows 25 samples of image patches from that cluster")

    PRNG = np.random.RandomState(0)
    for k in range(trainSS.K):
        Sigma_k = combinedModel.obsModel.get_covar_mat_for_comp(k)
        X_k = PRNG.multivariate_normal(np.zeros(64), Sigma_k, size=25)
        figH, axList = pylab.subplots(nrows=5, ncols=5)
        figH.canvas.set_window_title('Cluster %d: Sample patches' % (k))

        ii = 0
        for r in range(5):
            for c in range(5):
                axList[r, c].imshow(X_k[ii].reshape((8, 8)),
                                    interpolation='nearest',
                                    cmap='gray_r',
                                    vmin=-0.1,
                                    vmax=0.1)
                ii += 1
                axList[r, c].set_xticks([])
                axList[r, c].set_yticks([])

    pylab.show()
示例#4
0
def plotCompsAsSquareImages(
        phi,
        compsToHighlight=None,
        compListToPlot=None,
        activeCompIDs=None,
        xlabels=None,
        Kmax=50,
        W=1,
        H=1,
        figH=None,
        vocabList=None,  # catchall
        **kwargs):
    ''' Plot rows of phi as square images on a 3-column grid of subplots.

    Parameters
    ----------
    phi : 2D array, shape (K, V)
        V must be a perfect square. The row for each plotted component is
        reshaped to sqrt(V) x sqrt(V).
    compsToHighlight : int, array-like or None
        Component IDs drawn with a thick green border.
    compListToPlot : array-like of int or None
        Component IDs to plot, in order. Defaults to 0, ..., K-1.
    activeCompIDs : array-like of int or None
        Maps rows of phi to component IDs: row kk of phi is component
        activeCompIDs[kk]. Defaults to 0, ..., K-1. Components of
        compListToPlot that are not active get an empty (hidden) cell.
    xlabels : list of str or None
        If non-empty, xlabels[plotID] labels the plotID-th panel, and each
        cell is made 1.5x taller.
    Kmax : int
        Maximum number of panels. Extra components are dropped and a
        message is printed.
    W, H : float
        Width/height in inches of each grid cell.
    figH : matplotlib figure or None
        If given, the grid is drawn into this figure (via its `.number`);
        otherwise a new figure is created.
    vocabList : ignored (accepted for call-compatibility)
    **kwargs
        Extra pylab.imshow options, layered over module-level imshowArgs.

    Returns
    -------
    figH : the figure that holds the grid.
    '''
    if xlabels is None:
        xlabels = []
    curImshowArgs = dict(**imshowArgs)
    curImshowArgs.update(kwargs)

    if len(xlabels) > 0:
        H = 1.5 * H  # extra room for the labels
    K, V = phi.shape
    sqrtV = int(np.sqrt(V))
    assert np.allclose(sqrtV, np.sqrt(V))

    if compListToPlot is None:
        compListToPlot = np.arange(0, K)
    if activeCompIDs is None:
        activeCompIDs = np.arange(0, K)
    # Array form so the elementwise `==` lookup below also works for lists.
    activeCompIDs = np.asarray(activeCompIDs)
    compsToHighlight = np.asarray(compsToHighlight)
    if compsToHighlight.ndim == 0:
        compsToHighlight = np.asarray([compsToHighlight])

    # Create Figure
    Kplot = np.minimum(len(compListToPlot), Kmax)
    ncols = 3
    nrows = int(np.ceil(Kplot / float(ncols)))
    subplotKwargs = dict(nrows=nrows,
                         ncols=ncols,
                         figsize=(ncols * W, nrows * H))
    if figH is not None:
        # Use existing figure
        # TODO: Find a way to make this call actually change the figsize
        subplotKwargs['num'] = figH.number
    figH, _ = pylab.subplots(**subplotKwargs)

    # Number of grid cells used so far (also correct when we stop at Kmax).
    nPlotted = 0
    for plotID, compID in enumerate(compListToPlot):
        if plotID >= Kmax:
            print('DISPLAY LIMIT EXCEEDED. Showing %d/%d components' \
                % (plotID, len(activeCompIDs)))
            break
        nPlotted = plotID + 1

        if compID not in activeCompIDs:
            aH = pylab.subplot(nrows, ncols, plotID + 1)
            aH.axis('off')
            continue

        kk = np.flatnonzero(activeCompIDs == compID)[0]
        phiIm = np.reshape(phi[kk, :], (sqrtV, sqrtV))

        ax = pylab.subplot(nrows, ncols, plotID + 1)
        pylab.imshow(phiIm, aspect=1.0, **curImshowArgs)
        pylab.xticks([])
        pylab.yticks([])

        # Draw colored border around highlighted topics
        if compID in compsToHighlight:
            for spine in ax.spines.values():
                spine.set_color('green')
                spine.set_linewidth(3)

        if len(xlabels) > 0:
            pylab.xlabel(xlabels[plotID], fontsize=15)

    # Disable empty plots! Subplot indices are 1-based, so the first unused
    # cell is nPlotted + 1.
    for kdel in range(nPlotted + 1, nrows * ncols + 1):
        aH = pylab.subplot(nrows, ncols, kdel)
        aH.axis('off')

    # Fix margins between subplots
    pylab.subplots_adjust(wspace=0.04,
                          hspace=0.04,
                          left=0.01,
                          right=0.99,
                          top=0.99,
                          bottom=0.01)
    return figH
示例#5
0
def plotCompsFromWordCounts(WordCounts=None,
                            topics_KV=None,
                            vocabList=None,
                            compListToPlot=None,
                            compsToHighlight=None,
                            xlabels=None,
                            wordSizeLimit=10,
                            Ktop=10,
                            Kmax=32,
                            H=2.5,
                            W=2.0,
                            figH=None,
                            ncols=10,
                            ax_list=None,
                            fontsize=10,
                            proba_fmt_str="%.4f",
                            **kwargs):
    ''' Create subplots of top words from each topic, as monospace text.

    Each panel lists the Ktop highest-ranked words of one topic. Words are
    ranked by WordCounts when it is given, otherwise by topics_KV. A word
    with positive count is shown with its count (via count2str). Otherwise
    its topics_KV value is shown using proba_fmt_str. If neither is
    available, the word is omitted.

    Parameters
    ----------
    WordCounts : 1D or 2D array-like, optional
        Per-topic word counts, shape (K, V). A 1D input is treated as K=1.
    topics_KV : 2D array-like, optional
        Per-topic word probabilities, shape (K, V). Used when WordCounts is
        None, and for words whose count is zero.
    vocabList : list of str
        Required. Words are truncated to wordSizeLimit characters.
    compListToPlot : array-like of int, optional
        Topic IDs to plot. Defaults to all. At most Kmax are plotted.
    compsToHighlight : int, array-like or None
        Topic IDs whose panel gets a thick green border.
    xlabels : list of str or None
        If non-empty, xlabels[plotID] labels the plotID-th panel.
    figH : ignored on input unless ax_list is given; see Returns.
    ncols : int
        Grid columns when a new figure is created.
    ax_list : None, single axis, array of axes, or list of axes
        Axes to draw into. If None, a new figure is created.

    Returns
    -------
    figH : the newly created figure when ax_list was None; otherwise the
        figH argument as passed.
    ax_list : list of axes used.

    Post Condition
    --------------
    Each plotted topic has its own axis. Remaining axes are turned off.

    Raises
    ------
    ValueError if vocabList is None.
    '''
    if vocabList is None:
        raise ValueError('Missing vocabList. Cannot display topics.')
    if topics_KV is not None:
        # Convert up front so 2D indexing works even when WordCounts is the
        # primary source and topics_KV is only a fallback.
        topics_KV = np.asarray(topics_KV, dtype=np.float64)
    if WordCounts is not None:
        WordCounts = np.asarray(WordCounts, dtype=np.float64)
        if WordCounts.ndim == 1:
            WordCounts = WordCounts[np.newaxis, :]
        K, _ = WordCounts.shape
    else:
        K, _ = topics_KV.shape

    if compListToPlot is None:
        compListToPlot = np.arange(0, K)
    Kplot = np.minimum(len(compListToPlot), Kmax)
    if len(compListToPlot) > Kmax:
        print('DISPLAY LIMIT EXCEEDED. Showing %d/%d components' \
            % (Kplot, len(compListToPlot)))
    compListToPlot = compListToPlot[:Kplot]
    # Parse comps to highlight
    compsToHighlight = np.asarray(compsToHighlight)
    if compsToHighlight.ndim == 0:
        compsToHighlight = np.asarray([compsToHighlight])
    nrows = int(np.ceil(Kplot / float(ncols)))
    # Create Figure (and return it, rather than the unused figH argument)
    if ax_list is None:
        figH, ax_list = pylab.subplots(nrows=nrows,
                                       ncols=ncols,
                                       figsize=(ncols * W, nrows * H))
    if isinstance(ax_list, np.ndarray):
        ax_list = ax_list.flatten().tolist()
    elif str(type(ax_list)).count("matplotlib"):
        ax_list = [ax_list]  # degenerate case where subplots returns single ax
    assert isinstance(ax_list, list)
    n_images_to_plot = len(compListToPlot)

    for plotID, compID in enumerate(compListToPlot):
        cur_ax_h = ax_list[plotID]

        if WordCounts is None:
            scores = topics_KV[compID]
        else:
            scores = WordCounts[compID]
        topIDs = np.argsort(-1 * scores)
        lines = []
        for wID in topIDs[:Ktop]:
            word = vocabList[wID][:wordSizeLimit]
            if WordCounts is not None and WordCounts[compID, wID] > 0:
                wctStr = count2str(WordCounts[compID, wID])
                lines.append('%s %s\n' % (wctStr, word))
            elif topics_KV is not None:
                lines.append((proba_fmt_str + " %s\n") % (
                    topics_KV[compID, wID], word))
            # else: zero count and no probabilities -> nothing to show
        topicMultilineStr = ''.join(lines)
        cur_ax_h.text(0,
                      0,
                      topicMultilineStr,
                      fontsize=fontsize,
                      family=u'monospace')
        cur_ax_h.set_xlim([0, 1])
        cur_ax_h.set_ylim([0, 1])
        cur_ax_h.set_xticks([])
        cur_ax_h.set_yticks([])

        # Draw colored border around highlighted topics
        if compID in compsToHighlight:
            for spine in cur_ax_h.spines.values():
                spine.set_color('green')
                spine.set_linewidth(3)
        if xlabels is not None and len(xlabels) > 0:
            cur_ax_h.set_xlabel(xlabels[plotID], fontsize=11)

    # Disable empty plots
    for ax_h in ax_list[n_images_to_plot:]:
        ax_h.axis('off')

    return figH, ax_list
示例#6
0
def plotExampleBarsDocs(Data,
                        docIDsToPlot=None,
                        figID=None,
                        vmax=None,
                        nDocToPlot=16,
                        doShowNow=False,
                        seed=0,
                        randstate=np.random.RandomState(0),
                        xlabels=None,
                        W=1,
                        H=1,
                        **kwargs):
    ''' Show example documents as square images of their word counts.

    Each document's word-count vector (length Data.vocab_size, which must
    be a perfect square) is reshaped to sqrtV x sqrtV and drawn in a grid
    with 5 columns. Unused grid cells are turned off.

    Parameters
    ----------
    Data : dataset object
        Must provide vocab_size, nDoc, doc_range, word_id, word_count and
        getDocTypeCountMatrix().
    docIDsToPlot : array-like of int, optional
        Documents to show. If None, min(Data.nDoc, nDocToPlot) documents
        are drawn at random without replacement.
    figID : optional
        If None, a new figure is created. Otherwise no figure is created
        and drawing happens on pylab's current figure (the value of figID
        itself is not used).
    vmax : float, optional
        Upper color limit. If None, it is derived from the data: the max
        over words of each word's 98th-percentile count across documents,
        truncated to int and clipped to be at least 1.
    nDocToPlot : int
        Number of random documents to draw when docIDsToPlot is None.
    doShowNow : bool
        If True, call pylab.show() at the end.
    seed : int or None
        If not None, a fresh RandomState(seed) replaces randstate.
    randstate : numpy RandomState
        Used only when seed is None. NOTE: the default instance is created
        once at definition time and is shared between calls.
    xlabels : list of str, optional
        One label per plotted document.
    W, H : float
        Width/height in inches of each grid cell.
    **kwargs
        Extra pylab.imshow options. vmin=0, interpolation='nearest' and
        vmax are always set here.
    '''
    kwargs['vmin'] = 0
    kwargs['interpolation'] = 'nearest'
    if seed is not None:
        randstate = np.random.RandomState(seed)
    V = Data.vocab_size
    sqrtV = int(np.sqrt(V))
    assert np.allclose(sqrtV * sqrtV, V)
    if docIDsToPlot is None:
        size = np.minimum(Data.nDoc, nDocToPlot)
        docIDsToPlot = randstate.choice(Data.nDoc, size=size, replace=False)
    # Size the grid by the documents actually drawn. This may be fewer than
    # the requested nDocToPlot when the dataset is small.
    nDocToPlot = len(docIDsToPlot)
    ncols = 5
    nrows = int(np.ceil(nDocToPlot / float(ncols)))
    if vmax is None:
        DocWordArr = Data.getDocTypeCountMatrix()
        vmax = int(np.max(np.percentile(DocWordArr, 98, axis=0)))
        # Keep the color range non-degenerate (vmin is 0).
        vmax = max(vmax, 1)
    # Previously the data-derived vmax was computed but never passed on.
    kwargs['vmax'] = vmax

    if figID is None:
        pylab.subplots(nrows=nrows,
                       ncols=ncols,
                       figsize=(ncols * W, nrows * H))

    nPlotted = 0
    for plotPos, docID in enumerate(docIDsToPlot):
        # Rebuild a dense count vector from this doc's sparse tokens.
        start = Data.doc_range[docID]
        stop = Data.doc_range[docID + 1]
        wIDs = Data.word_id[start:stop]
        wCts = Data.word_count[start:stop]
        docWordHist = np.zeros(V)
        docWordHist[wIDs] = wCts
        squareIm = np.reshape(docWordHist, (sqrtV, sqrtV))
        pylab.subplot(nrows, ncols, plotPos + 1)
        pylab.imshow(squareIm, **kwargs)
        pylab.axis('image')
        pylab.xticks([])
        pylab.yticks([])
        if xlabels is not None:
            pylab.xlabel(xlabels[plotPos])
        nPlotted = plotPos + 1

    # Disable empty plots! (subplot indices are 1-based)
    for kdel in range(nPlotted + 1, nrows * ncols + 1):
        aH = pylab.subplot(nrows, ncols, kdel)
        aH.axis('off')

    # Fix margins between subplots
    pylab.subplots_adjust(wspace=0.04,
                          hspace=0.04,
                          left=0.01,
                          right=0.99,
                          top=0.99,
                          bottom=0.01)
    if doShowNow:
        pylab.show()
示例#7
0
def show_square_images(topics_KV=None,
                       xlabels=None,
                       max_n_images=50,
                       ncols=5,
                       ax_list=None,
                       im_width=1,
                       im_height=1,
                       fontsize=10,
                       **kwargs):
    ''' Show each row of topics_KV as a square image on its own axis.

    Parameters
    ----------
    topics_KV : 2D array, shape (K, V)
        V must be a perfect square. Row k is reshaped to sqrt(V) x sqrt(V).
    xlabels : list of str or None
        If non-empty, xlabels[k] becomes the x-label of image k.
    max_n_images : int
        At most this many rows are drawn.
    ncols : int
        Columns of a newly created grid (only used when ax_list is None).
        Capped at the number of images.
    ax_list : None, single axis, array of axes, or list of axes
        Axes to draw into. If None, a new figure with a grid is created.
        If fewer axes than images are given, only that many are drawn.
    im_width, im_height : float
        Size in inches of each cell of a newly created grid.
    fontsize : int
        Font size of the x-labels.
    **kwargs
        Overrides for the module-level imshowArgs. Only keys already present
        in imshowArgs are applied, and others are ignored.

    Returns
    -------
    ax_list : list of the axes used.

    Post Condition
    --------------
    Provided axes have plots updated. Axes beyond the plotted images are
    turned off.
    '''
    # Read-only use of the module-level defaults; no `global` needed.
    local_imshowArgs = dict(**imshowArgs)
    for key in local_imshowArgs:
        if key in kwargs:
            local_imshowArgs[key] = kwargs[key]

    K, V = topics_KV.shape
    sqrtV = int(np.sqrt(V))
    assert np.allclose(sqrtV, np.sqrt(V))

    n_images_to_plot = np.minimum(K, max_n_images)
    ncols = np.minimum(ncols, n_images_to_plot)
    if ax_list is None:
        # Make a new figure
        nrows = int(np.ceil(n_images_to_plot / float(ncols)))
        fig_h, ax_list = pylab.subplots(nrows=nrows,
                                        ncols=ncols,
                                        figsize=(ncols * im_width,
                                                 nrows * im_height))

    if isinstance(ax_list, np.ndarray):
        ax_list = ax_list.flatten().tolist()
    elif str(type(ax_list)).count("matplotlib"):
        ax_list = [ax_list]  # degenerate case where subplots returns single ax
    assert isinstance(ax_list, list)

    # Plot each row as square image
    for k, ax_h in enumerate(ax_list[:n_images_to_plot]):
        cur_im_sVsV = np.reshape(topics_KV[k, :], (sqrtV, sqrtV))
        ax_h.imshow(cur_im_sVsV, **local_imshowArgs)
        ax_h.set_xticks([])
        ax_h.set_yticks([])

        if xlabels is not None and len(xlabels) > 0:
            ax_h.set_xlabel(xlabels[k], fontsize=fontsize)

    # Disable empty plots
    for ax_h in ax_list[n_images_to_plot:]:
        ax_h.axis('off')

    return ax_list