Example #1
0
    def test_plot_helpers(self):
        """Exercise number formatting and box/DOF counting in pygsti's plothelpers."""
        from pygsti.report import plothelpers as helpers

        # (value, precision spec, expected label) triples checked against _eformat.
        eformat_cases = [
            (0.1, "compacthp", ".10"),
            (1.0, "compacthp", "1.0"),
            (5.2, "compacthp", "5.2"),
            (63.2, "compacthp", "63"),
            (2.1e-4, "compacthp", "2m4"),
            (2.1e+4, "compacthp", "2e4"),
            (-3.2e-4, "compacthp", "-3m4"),
            (-3.2e+4, "compacthp", "-3e4"),
            (4e+40, "compacthp", "*40"),
            (6e+102, "compacthp", "B"),
            (10, "compacthp", "10"),
            (1.234, 2, "1.23"),
            (-1.234, 2, "-1.23"),
            (-1.234, "foobar", "-1.234"),  # unrecognized spec: general format
        ]
        for value, spec, expected in eformat_cases:
            self.assertEqual(helpers._eformat(value, spec), expected)

        # A 3x3 grid of 2x2 matrices that are entirely NaN yields no boxes.
        grid = np.full((3, 3, 2, 2), np.nan, dtype='d')
        n_boxes, dof_per_box = helpers._compute_num_boxes_dof(
            grid, sumUp=True, element_dof=1)
        self.assertEqual(n_boxes, 0)
        self.assertEqual(dof_per_box, None)

        # Matrix [0,0] gets one finite entry and matrix [0,2] gets two, so the
        # non-empty matrices disagree on their NaN count (which should warn).
        for cell in [(0, 0, 1, 1), (0, 2, 0, 1), (0, 2, 1, 1)]:
            grid[cell] = 1.0

        # The warning itself is not asserted: Python 2.7 doesn't always emit it.
        helpers._compute_num_boxes_dof(grid, sumUp=True, element_dof=1)
Example #2
0
def plotly_to_matplotlib(pygsti_fig,
                         save_to=None,
                         fontsize=12,
                         prec='compacthp',
                         box_labels_font_size=6):
    """
    Convert a pygsti (plotly) figure to a matplotlib figure.

    Handles plotly traces of type "heatmap", "scatter", "scattergl", "bar",
    "histogram" and "box" (the latter rendered together as violins); traces
    of any other type are skipped.  Figures whose metadata contains
    ``special == "keyplot"`` are delegated to ``special_keyplot``.

    Parameters
    ----------
    pygsti_fig : ReportFigure
        A pyGSTi figure.

    save_to : str, optional
        Output filename.  Extension determines type.  If None, then the
        matplotlib figure is returned instead of saved.

    fontsize : int, optional
        Base fontsize to use for converted figure.

    prec : int or {"compact", "compacth", "compacthp"}, optional
        Digits of precision (or named compact format) used for the value
        labels written on heatmap boxes.

    box_labels_font_size : int, optional
        The size for labels on the boxes. If 0 then no labels are
        put on the boxes

    Returns
    -------
    matplotlib.Figure
        Matplotlib figure, unless save_to is not None, in which case
        the figure is closed and None is returned.

    Raises
    ------
    ValueError
        If the ``special`` metadata label is unrecognized, a scatter trace
        has an unknown mode (or a list of colors in "lines" mode), or saving
        leaves more matplotlib figures open than existed on entry.
    """
    numMPLFigs = len(_plt.get_fignums())
    fig = pygsti_fig.plotlyfig
    data_trace_list = fig['data']

    if 'special' in pygsti_fig.metadata:
        if pygsti_fig.metadata['special'] == "keyplot":
            return special_keyplot(pygsti_fig, save_to, fontsize)
        else:
            raise ValueError("Invalid `special` label: %s" %
                             pygsti_fig.metadata['special'])

    #if axes is None:
    mpl_fig, axes = _plt.subplots()  # create a new figure if no axes are given

    layout = fig['layout']
    h, w = layout['height'], layout['width']
    # todo: get margins and subtract from h,w

    if mpl_fig is not None and w is not None and h is not None:
        mpl_size = w / 100.0, h / 100.0  # heuristic: plotly pixels -> inches
        mpl_fig.set_size_inches(*mpl_size)  # was 12,8 for "super" color plot
        pygsti_fig.metadata[
            'mpl_fig_size'] = mpl_size  # record for later use by rendering commands

    def get(obj, x, default):
        """ Return obj[x], or `default` when the key is missing or maps to None.

        Needed b/c in plotly v3 layout no longer is a dict.
        """
        try:
            ret = obj[x]
        except KeyError:
            return default
        return ret if (ret is not None) else default

    xaxis, yaxis = layout['xaxis'], layout['yaxis']
    title = get(layout, 'title', None)
    shapes = get(layout, 'shapes', [])  # assume only shapes are grid lines
    bargap = get(layout, 'bargap', 0)

    xlabel = get(xaxis, 'title', None)
    ylabel = get(yaxis, 'title', None)
    xlabels = get(xaxis, 'ticktext', None)
    ylabels = get(yaxis, 'ticktext', None)
    xtickvals = get(xaxis, 'tickvals', None)
    ytickvals = get(yaxis, 'tickvals', None)
    xaxistype = get(xaxis, 'type', None)
    yaxistype = get(yaxis, 'type', None)
    xaxisside = get(xaxis, 'side', 'bottom')
    yaxisside = get(yaxis, 'side', 'left')
    xtickangle = get(xaxis, 'tickangle', 0)
    xlim = get(xaxis, 'range', None)
    ylim = get(yaxis, 'range', None)

    if xaxisside == "top":
        axes.xaxis.set_label_position('top')
        axes.xaxis.tick_top()

    if yaxisside == "right":
        axes.yaxis.set_label_position('right')
        axes.yaxis.tick_right()

    if title is not None:
        # Sometimes Title object still is nested
        title_text = title if isinstance(title, str) else get(
            title, 'text', '')
        if xaxisside == "top":
            # x-axis labels sit on top, so push the title up higher.  This
            # must be the only set_title call: a second call without `y`
            # would reset the title position.
            axes.set_title(mpl_process_lbl(title_text), fontsize=fontsize,
                           y=4)
        else:
            axes.set_title(mpl_process_lbl(title_text), fontsize=fontsize)

    if xlabel is not None:
        xlabel_text = xlabel if isinstance(xlabel, str) else get(
            xlabel, 'text', '')
        axes.set_xlabel(mpl_process_lbl(xlabel_text), fontsize=fontsize)

    if ylabel is not None:
        ylabel_text = ylabel if isinstance(ylabel, str) else get(
            ylabel, 'text', '')
        axes.set_ylabel(mpl_process_lbl(ylabel_text), fontsize=fontsize)

    if xtickvals is not None:
        axes.set_xticks(xtickvals, minor=False)

    if ytickvals is not None:
        axes.set_yticks(ytickvals, minor=False)

    if xlabels is not None:
        axes.set_xticklabels(mpl_process_lbls(xlabels),
                             rotation=0,
                             fontsize=(fontsize - 2))

    if ylabels is not None:
        axes.set_yticklabels(mpl_process_lbls(ylabels),
                             fontsize=(fontsize - 2))

    if xtickangle != 0:
        _plt.xticks(
            rotation=-xtickangle
        )  # minus b/c ploty & matplotlib have different sign conventions

    if xaxistype == 'log':
        axes.set_xscale("log")
    if yaxistype == 'log':
        axes.set_yscale("log")

    if xlim is not None:
        if xaxistype == 'log':  # plotly's limits are already log10'd in this case
            xlim = 10.0**xlim[0], 10.0**xlim[1]  # but matplotlib's aren't
        axes.set_xlim(xlim)

    if ylim is not None:
        if yaxistype == 'log':  # plotly's limits are already log10'd in this case
            ylim = 10.0**ylim[0], 10.0**ylim[1]  # but matplotlib's aren't
        axes.set_ylim(ylim)

    #figure out barwidth and offsets for bar plots
    num_bars = sum([get(d, 'type', '') == 'bar' for d in data_trace_list])
    currentBarOffset = 0
    barWidth = (1.0 - bargap) / num_bars if num_bars > 0 else 1.0

    #process traces
    handles = []
    labels = []  # for the legend
    boxes = []  # for violins
    for traceDict in data_trace_list:
        typ = get(traceDict, 'type', 'unknown')

        name = get(traceDict, 'name', None)
        showlegend = get(traceDict, 'showlegend', True)

        if typ == "heatmap":
            # traceDict['z'] is *normalized* already - maybe would work here but not for box value labels
            plt_data = pygsti_fig.metadata['plt_data']
            show_colorscale = get(traceDict, 'showscale', True)

            mpl_size = (plt_data.shape[1] * 0.5, plt_data.shape[0] * 0.5)
            mpl_fig.set_size_inches(*mpl_size)

            colormap = pygsti_fig.colormap
            assert (colormap
                    is not None), 'Must separately specify a colormap...'
            norm, cmap = colormap.create_matplotlib_norm_and_cmap()

            masked_data = _np.ma.array(plt_data, mask=_np.isnan(plt_data))
            heatmap = axes.pcolormesh(masked_data, cmap=cmap, norm=norm)

            axes.set_xlim(0, plt_data.shape[1])
            axes.set_ylim(0, plt_data.shape[0])

            # Box centers sit at integer + 0.5.  When no explicit tick values
            # are given, the centers are still needed to place the grid lines.
            if xtickvals is not None:
                xtics = _np.array(xtickvals) + 0.5
                axes.set_xticks(xtics, minor=False)
            else:
                xtics = _np.arange(plt_data.shape[1]) + 0.5

            if ytickvals is not None:
                ytics = _np.array(ytickvals) + 0.5
                axes.set_yticks(ytics, minor=False)
            else:
                ytics = _np.arange(plt_data.shape[0]) + 0.5

            grid = bool(len(shapes) > 1)
            if grid:

                def _get_minor_tics(t):
                    # midpoints between consecutive box centers
                    return [(t[i] + t[i + 1]) / 2.0 for i in range(len(t) - 1)]

                axes.set_xticks(_get_minor_tics(xtics), minor=True)
                axes.set_yticks(_get_minor_tics(ytics), minor=True)
                axes.grid(which='minor',
                          axis='both',
                          linestyle='-',
                          linewidth=2)

            off = False  # Matplotlib used to allow 'off', but now False should be used
            if xlabels is None and ylabels is None:
                axes.tick_params(labelcolor='w',
                                 top=off,
                                 bottom=off,
                                 left=off,
                                 right=off)  # white tics
            else:
                axes.tick_params(top=off, bottom=off, left=off, right=off)

            if box_labels_font_size > 0:
                # Write values on colored squares
                for y in range(plt_data.shape[0]):
                    for x in range(plt_data.shape[1]):
                        if _np.isnan(plt_data[y, x]): continue
                        assert (_np.isfinite(plt_data[y, x])
                                ), "%s is not finite!" % str(plt_data[y, x])
                        axes.text(
                            x + 0.5,
                            y + 0.5,
                            mpl_process_lbl(_eformat(plt_data[y, x], prec),
                                            math=True),
                            horizontalalignment='center',
                            verticalalignment='center',
                            color=mpl_besttxtcolor(plt_data[y, x], cmap, norm),
                            fontsize=box_labels_font_size)

            if show_colorscale:
                cbar = _plt.colorbar(heatmap)
                cbar.ax.tick_params(labelsize=(fontsize - 2))

        elif typ == "scatter":
            mode = get(traceDict, 'mode', 'lines')
            marker = get(traceDict, 'marker', None)
            line = get(traceDict, 'line', None)
            if marker and (line is None):
                line = get(marker, 'line', None)  # 2nd attempt to get line props

            # Start from None every trace so a color never leaks between traces.
            color = get(marker, 'color', None) if marker else None
            if line and (color is None):
                color = get(line, 'color', None)
            if color is None:
                color = 'rgb(0,0,0)'

            if isinstance(color, tuple):
                color = [mpl_color(c) for c in color]
            else:
                color = mpl_color(color)

            linewidth = float(line['width']) if (
                line and get(line, 'width', None) is not None) else 1.0

            x = y = None
            if 'x' in traceDict and 'y' in traceDict:
                x = traceDict['x']
                y = traceDict['y']
            elif 'r' in traceDict and 't' in traceDict:
                x = traceDict['r']
                y = traceDict['t']

            assert (x is not None and y
                    is not None), "x and y both None in trace: %s" % traceDict
            if mode == 'lines':
                if isinstance(color, list):
                    raise ValueError(
                        'List of colors incompatible with lines mode')
                lines = _plt.plot(x,
                                  y,
                                  linestyle='-',
                                  marker=None,
                                  color=color,
                                  linewidth=linewidth)
            elif mode == 'markers':
                # scatter returns a single collection; wrap it so that
                # lines[0] below works like it does for _plt.plot's list.
                lines = [_plt.scatter(x, y, marker=".", color=color)]
            elif mode == 'lines+markers':
                if isinstance(color, list):
                    # List of colors only works for markers with scatter, have default black line
                    lines = _plt.plot(x,
                                      y,
                                      linestyle='-',
                                      color=(0, 0, 0),
                                      linewidth=linewidth)
                    _plt.scatter(x, y, marker='.', color=color)
                else:
                    lines = _plt.plot(x,
                                      y,
                                      linestyle='-',
                                      marker='.',
                                      color=color,
                                      linewidth=linewidth)
            else:
                raise ValueError("Unknown mode: %s" % mode)

            if showlegend and name:
                handles.append(lines[0])
                labels.append(name)

        elif typ == "scattergl":  # currently used only for colored points...
            x = traceDict['x']
            y = traceDict['y']
            assert (x is not None and y
                    is not None), "x and y both None in trace: %s" % traceDict

            colormap = pygsti_fig.colormap
            if colormap:
                norm, cmap = colormap.create_matplotlib_norm_and_cmap()
                s = _plt.scatter(x, y, c=y, s=50, cmap=cmap, norm=norm)
            else:
                s = _plt.scatter(x, y, c=y, s=50, cmap='gray')

            if showlegend and name:
                handles.append(s)
                labels.append(name)

        elif typ == "bar":
            # x "values" are actually bar labels in plotly (note: this
            # rebinds the axis-level xlabels for the remaining traces)
            xlabels = [str(xl) for xl in traceDict['x']]

            #always grey=pos, red=neg type of bar plot for now (since that's all pygsti uses)
            y = _np.asarray(traceDict['y'])
            if 'plt_yerr' in pygsti_fig.metadata:
                yerr = pygsti_fig.metadata['plt_yerr']
            else:
                yerr = None

            # actual x values are just the integers + offset
            x = _np.arange(y.size) + currentBarOffset
            currentBarOffset += barWidth  # so next bar trace will be offset correctly

            marker = get(traceDict, 'marker', None)
            marker_color = get(marker, 'color', None) if marker else None
            color = "gray"  # used when no usable marker color is given
            if isinstance(marker_color, str):
                color = mpl_color(marker_color)
            elif isinstance(marker_color, (list, tuple)):
                # b/c axes.bar can take a list of colors; plotly objects
                # may store such arrays as tuples
                color = [mpl_color(c) for c in marker_color]

            if yerr is None:
                axes.bar(x, y, barWidth, color=color)
            else:
                axes.bar(x, y, barWidth, color=color, yerr=yerr.flatten().real)

            if xtickvals is not None:
                xtics = _np.array(
                    xtickvals) + 0.5  # _np.arange(plt_data.shape[1])+0.5
            else:
                xtics = x
            axes.set_xticks(xtics, minor=False)
            axes.set_xticklabels(mpl_process_lbls(xlabels),
                                 rotation=0,
                                 fontsize=(fontsize - 4))

        elif typ == "histogram":
            marker = get(traceDict, 'marker', None)
            marker_color = get(marker, 'color', None) if marker else None
            color = mpl_color(marker_color if isinstance(marker_color, str)
                              else "gray")
            xbins = traceDict['xbins']
            histdata = traceDict['x']

            if abs(xbins['size']) < 1e-6:
                histBins = 1
            else:
                histBins = int(
                    round((xbins['end'] - xbins['start']) / xbins['size']))

            histdata_finite = _np.take(
                histdata, _np.where(_np.isfinite(histdata)))[
                    0]  # take gives back (1,N) shaped array (why?)
            if yaxistype == 'log':
                if len(histdata_finite) == 0:
                    axes.set_yscale(
                        "linear"
                    )  # no data, and will get an error with log-scale, so switch to linear

            _, _, patches = _plt.hist(histdata_finite,
                                      histBins,
                                      facecolor=color,
                                      align='mid')

            #If we've been given an array of colors, color each bin separately
            if isinstance(marker_color, (list, tuple)):
                for p, c in zip(patches, marker_color):
                    _plt.setp(p, 'facecolor', mpl_color(c))

        elif typ == "box":
            boxes.append(traceDict)

    if len(boxes) > 0:
        _plt.violinplot([box['y'] for box in boxes],
                        [box['x0'] for box in boxes],
                        points=10,
                        widths=1.,
                        showmeans=False,
                        showextrema=False,
                        showmedians=False)
        # above kwargs taken from Tim's original RB plot - we could set some of
        # these from boxes[0]'s properties like 'boxmean' (a boolean) FUTURE?

    extraartists = [axes]
    if len(handles) > 0:
        lgd = _plt.legend(handles,
                          labels,
                          bbox_to_anchor=(1.01, 1.0),
                          borderaxespad=0.,
                          loc="upper left")
        extraartists.append(lgd)

    if save_to:
        _gc.collect(
        )  # too many open files (b/c matplotlib doesn't close everything) can cause the below to fail
        _plt.savefig(save_to,
                     bbox_extra_artists=extraartists,
                     bbox_inches='tight')  # need extra artists otherwise
        #axis labels get clipped
        _plt.cla()
        _plt.close(mpl_fig)
        del mpl_fig
        _gc.collect()  # again, to be safe...
        if len(_plt.get_fignums()) != numMPLFigs:
            raise ValueError(
                "WARNING: MORE FIGURES OPEN NOW (%d) THAN WHEN WE STARTED %d)!!"
                % (len(_plt.get_fignums()), numMPLFigs))
        return None  # figure is closed!
    else:
        return mpl_fig