Beispiel #1
0
def test_Bug_2543():
    """Regression test for matplotlib issue #2543.

    It must be possible to re-assign every rcParam to its own current value
    and to deep-copy ``mpl.rcParams``.  Both used to fail because
    ``validate_bool_maybe_none`` did not accept ``None``.
    https://github.com/matplotlib/matplotlib/issues/2543
    """
    # Deprecated/obsolete rcParams rightly raise UserWarnings while being
    # re-assigned; silence them so they don't clutter the test output.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore',
                                message='.*(deprecated|obsolete)',
                                category=UserWarning)

        # Round-trip every parameter through assignment.
        with mpl.rc_context():
            snapshot = mpl.rcParams.copy()
            for name in six.iterkeys(snapshot):
                mpl.rcParams[name] = snapshot[name]
            mpl.rcParams['text.dvipnghack'] = None

        # The real check here is simply that deepcopy does not raise.
        with mpl.rc_context():
            from copy import deepcopy
            _deep_copy = deepcopy(mpl.rcParams)

        assert_true(validate_bool_maybe_none(None) is None)
        assert_true(validate_bool_maybe_none("none") is None)

        # svg.embed_char_paths and svg.fonttype are expected to agree, and
        # setting embed_char_paths=False must be reflected as fonttype "none".
        current_fonttype = mpl.rcParams['svg.fonttype']
        assert_true(current_fonttype == mpl.rcParams['svg.embed_char_paths'])
        with mpl.rc_context():
            mpl.rcParams['svg.embed_char_paths'] = False
            assert_true(mpl.rcParams['svg.fonttype'] == "none")
Beispiel #2
0
def test_rcparams():
    """Check that ``rc_context`` and ``rc_file`` apply and undo rc settings.

    Uses ``test_rcparams.rc`` (located next to this module), which sets
    ``lines.linewidth`` to 33.  Everything, including the ``mpl.rc``
    baseline and the final ``rc_file`` call, runs inside an outer
    ``rc_context`` so the global rcParams are restored when the test ends
    instead of leaking into later tests.
    """
    fname = os.path.join(os.path.dirname(__file__), 'test_rcparams.rc')

    with mpl.rc_context():
        # Known baseline for the checks below.
        mpl.rc('text', usetex=False)
        mpl.rc('lines', linewidth=22)

        usetex = mpl.rcParams['text.usetex']
        linewidth = mpl.rcParams['lines.linewidth']

        # test context given dictionary
        with mpl.rc_context(rc={'text.usetex': not usetex}):
            assert mpl.rcParams['text.usetex'] == (not usetex)
        assert mpl.rcParams['text.usetex'] == usetex

        # test context given filename (the rc file sets linewidth to 33)
        with mpl.rc_context(fname=fname):
            assert mpl.rcParams['lines.linewidth'] == 33
        assert mpl.rcParams['lines.linewidth'] == linewidth

        # test context given filename and dictionary (dictionary wins)
        with mpl.rc_context(fname=fname, rc={'lines.linewidth': 44}):
            assert mpl.rcParams['lines.linewidth'] == 44
        assert mpl.rcParams['lines.linewidth'] == linewidth

        # test rc_file; the outer rc_context undoes this on exit
        mpl.rc_file(fname)
        assert mpl.rcParams['lines.linewidth'] == 33
def test_rcparams():
    """Check that ``rc_context`` and ``rc_file`` apply and undo rc settings.

    Uses ``test_rcparams.rc`` (located next to this module), which sets
    ``lines.linewidth`` to 33.  ``lines.linewidth`` is restored at the end
    so the ``rc_file`` call does not leak into other tests.
    """
    # The rc file used by every filename-based check below.
    fname = os.path.join(os.path.dirname(__file__), 'test_rcparams.rc')

    usetex = mpl.rcParams['text.usetex']
    linewidth = mpl.rcParams['lines.linewidth']

    # test context given dictionary
    with mpl.rc_context(rc={'text.usetex': not usetex}):
        assert mpl.rcParams['text.usetex'] == (not usetex)
    assert mpl.rcParams['text.usetex'] == usetex

    # test context given filename (the rc file sets linewidth to 33)
    with mpl.rc_context(fname=fname):
        assert mpl.rcParams['lines.linewidth'] == 33
    assert mpl.rcParams['lines.linewidth'] == linewidth

    # test context given filename and dictionary (dictionary wins)
    with mpl.rc_context(fname=fname, rc={'lines.linewidth': 44}):
        assert mpl.rcParams['lines.linewidth'] == 44
    assert mpl.rcParams['lines.linewidth'] == linewidth

    # test rc_file; it changes global state, so undo it afterwards
    try:
        mpl.rc_file(fname)
        assert mpl.rcParams['lines.linewidth'] == 33
    finally:
        mpl.rcParams['lines.linewidth'] = linewidth
Beispiel #4
0
def test_Bug_2543():
    """Regression test for matplotlib issue #2543.

    Re-assigning every rcParam to its own value and deep-copying
    ``mpl.rcParams`` used to fail because ``validate_bool_maybe_none`` did
    not accept ``None``.  Invalid values must still be rejected.
    https://github.com/matplotlib/matplotlib/issues/2543
    """
    # Deprecated rcParams correctly warn when touched; keep those warnings
    # out of the test output.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore',
                                category=MatplotlibDeprecationWarning)

        with mpl.rc_context():
            snapshot = mpl.rcParams.copy()
            for name in snapshot:
                mpl.rcParams[name] = snapshot[name]

        with mpl.rc_context():
            # The real check here is simply that deepcopy does not raise.
            _deep_copy = copy.deepcopy(mpl.rcParams)

        assert validate_bool_maybe_none(None) is None
        assert validate_bool_maybe_none("none") is None

    # Bad inputs must still be rejected by the validators.
    for validator, bad_value in ((validate_bool_maybe_none, "blah"),
                                 (validate_bool, None)):
        with pytest.raises(ValueError):
            validator(bad_value)

    # svg.fonttype does not accept a boolean.
    with pytest.raises(ValueError):
        with mpl.rc_context():
            mpl.rcParams['svg.fonttype'] = True
Beispiel #5
0
 def test_to_string(self):
     """Check ``to_string`` output both without and with LaTeX rendering."""
     # (usetex flag, ((input, expected output), ...)) in the original order.
     expectations = (
         (False, (('test', 'test'),
                  (4.0, '4.0'),
                  (8, '8'))),
         (True, (('test', 'test'),
                 (2000, r'2\!\!\times\!\!10^{3}'),
                 (8, '8'))),
     )
     for usetex, cases in expectations:
         with rc_context(rc={'text.usetex': usetex}):
             for value, expected in cases:
                 assert to_string(value) == expected
Beispiel #6
0
def test_rcparams_reset_after_fail():
    """A failing ``rc_context`` must not leave its changes behind.

    Previously, if ``rc_context`` raised because of a problem in the supplied
    rc parameters, the global rcParams were left in a modified state.
    """
    with mpl.rc_context(rc={'text.usetex': False}):
        assert mpl.rcParams['text.usetex'] is False

        # Ordered so that the valid 'text.usetex' entry comes before the
        # unknown key; entering the context is expected to raise KeyError.
        bad_rc = OrderedDict([('text.usetex', True), ('test.blah', True)])
        with assert_raises(KeyError):
            with mpl.rc_context(rc=bad_rc):
                pass

        # The valid entry must have been rolled back.
        assert mpl.rcParams['text.usetex'] is False
def test_jpeg_alpha():
    """Saving RGBA data as JPEG must blend alpha onto the savefig facecolor.

    Skipped when Pillow is not installed.
    """
    Image = pytest.importorskip('PIL.Image')

    plt.figure(figsize=(1, 1), dpi=300)
    # RGBA image: RGB is all zero (black); alpha ramps 0 -> 1 left to right.
    rgba = np.zeros((300, 300, 4), dtype=float)
    rgba[..., 3] = np.linspace(0.0, 1.0, 300)
    plt.figimage(rgba)

    buff = io.BytesIO()
    with rc_context({'savefig.facecolor': 'red'}):
        plt.savefig(buff, transparent=True, format='jpg', dpi=300)
    buff.seek(0)
    image = Image.open(buff)

    # If alpha were ignored the image would be a single colour (all black);
    # a working blend yields many distinct shades, hence a range check.
    assert 175 <= len(image.getcolors(256)) <= 185
    # The fully transparent corner should show the red facecolor.
    assert image.getpixel((0, 0)) == (254, 0, 0)
Beispiel #8
0
 def style(self):
     """Return a context manager that applies this object's rc settings.

     ``self.rc_context`` may be:

     - ``None``: no styling; an ``EmptyContext()`` is returned.
     - a dict (including dict subclasses such as ``OrderedDict``): wrapped
       in ``mpl.rc_context``.
     - anything else: returned unchanged (presumably already a context
       manager).
     """
     rc = self.rc_context
     if rc is None:
         return EmptyContext()
     # isinstance rather than ``type(rc) is dict``: a dict subclass would
     # otherwise be returned as-is and then fail in a ``with`` statement,
     # because mappings are not context managers.
     if isinstance(rc, dict):
         return mpl.rc_context(rc)
     return rc
Beispiel #9
0
    def plot_phase(self, rcParams=None):
        """Plot the phase of ``self.phi_lock`` against time.

        Parameters
        ----------
        rcParams : dict, optional
            Extra matplotlib rc settings for this plot only.  They are
            layered over ``self.rcParams`` without modifying it.

        Returns
        -------
        fig, ax
            The created figure and axes.

        Raises
        ------
        ValueError
            If ``set_filters`` has not been run yet (no ``phi_lock_phase``).
        """
        if not hasattr(self, 'phi_lock_phase'):
            raise ValueError("set_filters before plotting phase")

        # Work on a copy: updating self.rcParams in place would make one
        # call's overrides stick to every later plot.
        rc = dict(self.rcParams)
        if rcParams:
            rc.update(rcParams)

        with mpl.rc_context(rc):
            fig, ax = plt.subplots()
            phase = np.angle(self.phi_lock)
            ax.plot(self.tm, phase[self.m])
            if hasattr(self, 'phase'):
                # Overlay the reference phase, with the error wrapped by
                # phase_err.
                adjusted_phase = phase_err(self.phase - phase)
                ax.plot(self.tm, (phase + adjusted_phase)[self.m])
            if hasattr(self, 'm_fit'):
                # Shade the time span covered by the fit mask.
                ax.axvspan(np.min(self.t[self.m_fit]),
                           np.max(self.t[self.m_fit]),
                           alpha=0.5,
                           color='0.5')

            ax.set_xlabel('Time [s]')
            ax.set_ylabel('Phase [rad.]')

        return fig, ax
Beispiel #10
0
def plot_simple(x, y, magic=None, scale='linear', xlim=None, ylim=None,
                xlabel=None, ylabel=None, figax=None,
                rcParams={'backend': 'Qt4Agg'}, **plot_kwargs):
    """Plot *y* against *x* with optional axis scaling, labels and limits.

    Parameters
    ----------
    x, y : array-like
        Data to plot.
    magic : optional
        Extra positional argument passed after ``x, y`` (presumably a
        format string such as ``'r.'``).
    scale : {'linear', 'semilogy', 'semilogx', 'loglog'}
        Selects ``ax.plot``, ``ax.semilogy``, ``ax.semilogx`` or
        ``ax.loglog``.  Any other value raises ``KeyError``.
    xlim, ylim : optional
        Axis limits, applied only when given.
    xlabel, ylabel : str, optional
        Axis labels, applied only when given.
    figax : (Figure, Axes), optional
        Draw on these instead of creating a new figure.
    rcParams : dict
        rc settings in effect while a new figure is created.  Ignored when
        *figax* is given.  The dict is not modified.
    **plot_kwargs
        Forwarded to the plotting call.

    Returns
    -------
    fig, ax
    """
    if figax is None:
        with mpl.rc_context(rc=rcParams):
            fig = plt.figure()
            ax = fig.add_subplot(111)
    else:
        fig, ax = figax

    plotters = dict(linear=ax.plot,
                    semilogy=ax.semilogy,
                    semilogx=ax.semilogx,
                    loglog=ax.loglog)
    positional = (x, y) if magic is None else (x, y, magic)
    plotters[scale](*positional, **plot_kwargs)

    # Apply each optional setting that was actually provided.
    optional_settings = ((ax.set_xlabel, xlabel),
                         (ax.set_ylabel, ylabel),
                         (ax.set_xlim, xlim),
                         (ax.set_ylim, ylim))
    for setter, value in optional_settings:
        if value is not None:
            setter(value)

    return fig, ax
Beispiel #11
0
    def plot_roc(self):
        """
        Plot the receiver operating characteristic (ROC) curve for this
        subject's classifier.

        Reads ``self.res['y']`` (labels), ``self.res['probs']`` (scores) and
        ``self.res['auc']``.  If ``self.res`` contains ``'p_val'``, it is
        shown in the title; a p-value of exactly 0 is reported as the bound
        ``< 1 / self.num_iters``.  The figure is resized to 12x9 inches.
        """
        fpr, tpr, _ = roc_curve(self.res['y'], self.res['probs'])
        auc = self.res['auc']
        with plt.style.context('fivethirtyeight'):
            with mpl.rc_context({'ytick.labelsize': 16,
                                 'xtick.labelsize': 16}):
                plt.plot(fpr, tpr, lw=4, label='ROC curve (AUC = %0.2f)' % auc)
                # Chance diagonal, kept out of the legend.
                plt.plot([0, 1], [0, 1], color='k', lw=2, linestyle='--', label='_nolegend_')
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.05])
                plt.xlabel('False Positive Rate', fontsize=24)
                plt.ylabel('True Positive Rate', fontsize=24)
                plt.legend(loc="lower right")
                if 'p_val' not in self.res:
                    title = 'ROC (AUC: {0:.3f})'.format(auc)
                else:
                    p = self.res['p_val']
                    if p == 0:
                        # Only an upper bound can be stated (presumably the
                        # resolution of a num_iters-iteration permutation
                        # test).  1.0 forces true division even on Python 2,
                        # where 1 / num_iters would floor to 0.
                        p_str = '< {0:.2f}'.format(1.0 / self.num_iters)
                    else:
                        p_str = '= {0:.3f}'.format(p)
                    title = 'ROC (AUC: {0:.3f}, p{1})'.format(auc, p_str)
                plt.title(title)
        plt.gcf().set_size_inches(12, 9)
Beispiel #12
0
def demo(ax, rcparams, title):
    """Draw a fixed 5x5 random image on *ax* under temporary rc settings.

    The global NumPy RNG is reseeded with 2 on every call, so each call
    shows the same data; only the styling taken from *rcparams* varies.
    """
    np.random.seed(2)
    image_data = np.random.rand(5, 5)

    # rcparams only apply while the image and title are created.
    with mpl.rc_context(rc=rcparams):
        ax.imshow(image_data)
        ax.set_title(title)
Beispiel #13
0
    def __init__(self, port, baud, check = False):
        """
            Initialize the main display: a single figure fed by a serial port.
            :port: serial port index or name (eg.: COM4)
            :baud: baud rate (eg.: 115200)
            :check: stored on self.check; how it is used is not visible here
            NOTE: this constructor does not return until the figure window is
            closed, because it calls plt.show(block=True).
        """
        self.frame = IM_Frame()
        self.new_frame = False
        self.lock = threading.Lock()
        self.check = check

        # disable figure toolbar, bind close event and open serial port
        # NOTE(review): the 'toolbar' rcParam normally takes a string such as
        # 'None'; confirm the installed matplotlib accepts the boolean False.
        # NOTE(review): if anything below raises after serial.Serial(...) has
        # succeeded, the port is not closed here.
        with mpl.rc_context({'toolbar':False}):
            self.fig = plt.figure()
            # 8N1, no software flow control, 0.25 s read timeout.
            self.serial_port = serial.Serial(port, baud, timeout=0.25,bytesize=serial.EIGHTBITS, parity=serial.PARITY_NONE, stopbits=serial.STOPBITS_ONE, xonxoff=False)
            self.serial_port.flush()
            self.fig.canvas.mpl_connect('close_event', self.close_display)
            self.fig.canvas.mpl_connect('resize_event', self.resize_display)
            self.fig.canvas.set_window_title('pyIM')

            # window position is set before state to maximized in update_graph
            # wm_geometry is a Tk window method, so this presumably needs a Tk
            # backend; "+(w-1)+0" puts the window's left edge at x = w-1.
            w,h=getVirtualScreenSize()
            plt.get_current_fig_manager().window.wm_geometry(("+%d+%d"%(w-1,0)))

            # create timer for updating display using a rs232 polling callback
            # (fires every 250 ms)
            timer = self.fig.canvas.new_timer(interval=250)
            timer.add_callback(self.update_graphs) # then arg if needed
            timer.start()

            # launch figure (blocks until the window is closed)
            plt.show(block = True)

            # close correctly serial port
            # NOTE(review): close_display is also bound to 'close_event' above,
            # so it may run twice; confirm it is safe to call repeatedly.
            self.close_display()
Beispiel #14
0
 def test_hist(self, table):
     """``table.hist('snr')`` should give a HistogramPlot that can be saved."""
     with rc_context(rc={'text.usetex': False}):
         plot = table.hist('snr')
         assert isinstance(plot, HistogramPlot)
         # Ten patches are expected on the current axes.
         bars = plot.gca().patches
         assert len(bars) == 10
         with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
             plot.save(tmp.name)
Beispiel #15
0
def plot_phasekick_corrected(df, extras, figax=None, rcParams=None, filename=None):
    """Plot the corrected phase shift against pulse time for data and control.

    Parameters
    ----------
    df : DataFrame
        Presumably indexed so that ``df.loc['data']`` and ``df.loc['control']``
        select the two groups.  Needs a ``tp`` column and a
        ``'dphi_corrected [cyc]'`` column.
    extras : dict
        Must contain ``'popt_phase_corr'``, the parameters passed to
        ``phase_step`` for the fitted curve.
    figax : (Figure, Axes), optional
        Existing figure and axes to draw on.  A new pair is created if omitted.
    rcParams : dict, optional
        rc settings applied while creating a new figure.  ``None`` (the
        default) means no overrides.
    filename : str, optional
        If given, the figure is saved there with a tight bounding box.

    Returns
    -------
    fig, ax
    """
    if figax is None:
        # None instead of a shared mutable ``{}`` default; an empty dict
        # means "no rc overrides".
        with mpl.rc_context(rcParams if rcParams is not None else {}):
            fig, ax = plt.subplots()
    else:
        fig, ax = figax

    data = df.loc['data']
    control = df.loc['control']

    # Show millicycles when the largest data shift is at most 0.15 cycles.
    if abs(data['dphi_corrected [cyc]']).max() > 0.15:
        units, scale = 'cyc', 1
    else:
        units, scale = 'mcyc', 1e3

    # tp * 1e3 matches the "[ms]" axis label (tp presumably in seconds).
    ax.plot(control.tp * 1e3, control['dphi_corrected [cyc]'] * scale, 'g.')
    ax.plot(data.tp * 1e3, data['dphi_corrected [cyc]'] * scale, 'b.')
    ax.plot(data.tp * 1e3, phase_step(data.tp, *extras['popt_phase_corr']) * scale)

    ax.set_xlabel('Pulse time [ms]')
    ax.set_ylabel('Phase shift [{}.]'.format(units))
    if filename is not None:
        fig.savefig(filename, bbox_inches='tight')
    return fig, ax
Beispiel #16
0
def test_LogFormatterSciNotation():
    """Check LogFormatterSciNotation's mathtext output for bases 10 and 2."""
    cases_by_base = (
        (10, (
            (-1, '${-10^{0}}$'),
            (1e-05, '${10^{-5}}$'),
            (1, '${10^{0}}$'),
            (100000, '${10^{5}}$'),
            (2e-05, '${2\\times10^{-5}}$'),
            (2, '${2\\times10^{0}}$'),
            (200000, '${2\\times10^{5}}$'),
            (5e-05, '${5\\times10^{-5}}$'),
            (5, '${5\\times10^{0}}$'),
            (500000, '${5\\times10^{5}}$'),
        )),
        (2, (
            (0.03125, '${2^{-5}}$'),
            (1, '${2^{0}}$'),
            (32, '${2^{5}}$'),
            (0.0375, '${1.2\\times2^{-5}}$'),
            (1.2, '${1.2\\times2^{0}}$'),
            (38.4, '${1.2\\times2^{5}}$'),
        )),
    )

    for base, cases in cases_by_base:
        formatter = mticker.LogFormatterSciNotation(base=base)
        # Mantissas that are allowed to be labelled.
        formatter.sublabel = {1, 2, 5, 1.2}
        for value, expected in cases:
            # Plain mathtext output, not the usetex variant.
            with matplotlib.rc_context({'text.usetex': False}):
                assert formatter(value) == expected
    def __init__(self, data, layout=(1, 1)):
        """Create a paged "Browser" figure for stepping through *data*.

        Parameters
        ----------
        data : sequence
            Items to browse.  Only ``len(data)`` is used here; drawing is done
            by methods outside this view (``_first_plot`` and the button
            callbacks).
        layout : (int, int)
            Number of subplot rows and columns shown per page.
        """
        self._data = data
        self._layout = layout

        self._number_of_plots = len(data)
        # numpy.prod: numpy.product was deprecated and removed in NumPy 2.0.
        self._plots_per_page = _numpy.prod(layout)
        # Ceiling division with ``//`` keeps the page count an integer on
        # Python 3 too, where ``/`` would produce a float.
        self._number_of_pages = (self._number_of_plots - 1) // self._plots_per_page + 1
        self._current_page = 0

        # Create (or reuse) the "Browser" window without a toolbar.
        with _matplotlib.rc_context({"toolbar": "None"}):
            self._fig = _matplotlib.pyplot.figure("Browser")
            self._fig.clear()
        self._fig.set_facecolor("white")
        # Leave the bottom 10% free for the navigation buttons.
        self._fig.subplots_adjust(left=0., bottom=0.1, right=1., top=0.95, wspace=0., hspace=0.05)

        # One axes per grid cell, in row-major order.  ``range`` rather than
        # ``xrange`` so this also runs on Python 3.
        n_rows, n_cols = layout[0], layout[1]
        self._axes = [self._fig.add_subplot(n_rows, n_cols, 1 + i0 * n_cols + i1)
                      for i0, i1 in _itertools.product(range(n_rows), range(n_cols))]

        ax_prev = self._fig.add_axes([0.1, 0.01, 0.3, 0.09])
        ax_next = self._fig.add_axes([0.6, 0.01, 0.3, 0.09])

        self._button_prev = _matplotlib.widgets.Button(ax_prev, "Previous")
        self._button_next = _matplotlib.widgets.Button(ax_next, "Next")
        self._button_prev.on_clicked(self._prev)
        self._button_next.on_clicked(self._next)

        self._page_text = self._fig.text(0.5, 0.05, "", va="center", ha="center")
        self._update_page_string()
        self._first_plot()
Beispiel #18
0
def context(style, after_reset=False):
    """Context manager for using style settings temporarily.

    All rcParams changes made while the context is active are undone when
    it exits.

    Parameters
    ----------
    style : str, dict, or list
        A style specification:

        - str: the name of a style, or a path/URL to a style file (see
          `style.available` for the available names).
        - dict: valid key/value pairs for `matplotlib.rcParams`.
        - list: several str/dict specifiers, applied from first to last.

    after_reset : bool
        If True, reset rcParams to matplotlib's defaults before applying
        *style*.  If False (the default), layer *style* on top of the
        current settings.
    """
    with mpl.rc_context():
        # Optionally start from a clean slate instead of the current settings.
        if after_reset:
            mpl.rcdefaults()
        use(style)
        # Run the caller's ``with`` body.  Leaving rc_context restores the
        # previous rcParams.
        yield
Beispiel #19
0
 def test_view_limits(self):
     """
     In 'data' autolimit mode, MultipleLocator.view_limits must return the
     requested range unchanged.
     """
     with matplotlib.rc_context({'axes.autolimit_mode': 'data'}):
         locator = mticker.MultipleLocator(base=3.147)
         limits = locator.view_limits(-5, 5)
         assert_almost_equal(limits, (-5, 5))
Beispiel #20
0
    def _init_axis(self, fig, axis):
        """
        Return ``(fig, axis)``, creating a new figure and axis when *fig* is
        falsy and ``self._create_fig`` is set.  Otherwise the arguments are
        returned unchanged.

        The new figure is built under ``self.fig_rcparams`` (with
        ``text.usetex`` forced on when ``self.fig_latex`` is set).  Its
        bounds, alpha and size come from ``self.fig_bounds``,
        ``self.fig_alpha`` and ``self.fig_inches``.
        """
        if not fig and self._create_fig:
            # Copy before adding 'text.usetex' so the setting does not leak
            # back into self.fig_rcparams (which may be shared, e.g. a
            # class-level default).  ``or {}`` keeps a None value working.
            rc_params = dict(self.fig_rcparams or {})
            if self.fig_latex:
                rc_params['text.usetex'] = True
            with mpl.rc_context(rc=rc_params):
                fig = plt.figure()
                l, b, r, t = self.fig_bounds
                inches = self.fig_inches
                fig.subplots_adjust(left=l, bottom=b, right=r, top=t)
                fig.patch.set_alpha(self.fig_alpha)
                if isinstance(inches, (tuple, list)):
                    # A None dimension copies the other one (square figure).
                    inches = list(inches)
                    if inches[0] is None:
                        inches[0] = inches[1]
                    elif inches[1] is None:
                        inches[1] = inches[0]
                    fig.set_size_inches(inches)
                else:
                    # A scalar size means a square figure.
                    fig.set_size_inches([inches, inches])
                axis = fig.add_subplot(111, projection=self.projection)
                axis.set_aspect('auto')

        return fig, axis
Beispiel #21
0
def plot_dA_dphi_vs_t(df, extras, figax=None, rcParams=None, filename=None):
    """Plot control-group amplitude and phase changes against delay time.

    Parameters
    ----------
    df : DataFrame
        Presumably indexed so that ``.loc['control']`` selects the control
        rows.  Needs the columns ``'phi_at_tp [rad]'``, ``'f0 [Hz]'``,
        ``'dA [nm]'`` and ``'dphi_tp_end [cyc]'``.
    extras : dict
        Must contain ``'popt_A'`` and ``'popt_phi'``, the parameters passed
        to ``offset_cos`` for the two fitted curves.
    figax : (Figure, (Axes, Axes)), optional
        Existing figure and two stacked axes.  Created with a shared x-axis
        if omitted.
    rcParams : dict, optional
        rc settings applied while creating a new figure.  ``None`` (the
        default) means no overrides.
    filename : str, optional
        If given, the figure is saved there with a tight bounding box.

    Returns
    -------
    fig, (ax1, ax2)
    """
    if figax is None:
        # None instead of a shared mutable ``{}`` default; an empty dict
        # means "no rc overrides".
        with mpl.rc_context(rcParams if rcParams is not None else {}):
            fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
    else:
        fig, (ax1, ax2) = figax

    df_sorted = df.sort_values('phi_at_tp [rad]')

    control = df_sorted.loc['control']
    f0 = control['f0 [Hz]'].median()

    phi = control['phi_at_tp [rad]']
    # Convert phase to delay time, phi / (2*pi*f0), expressed in microseconds.
    td = phi / (2 * np.pi * f0) * 1e6
    ax1.plot(td, control['dA [nm]'], 'b.')
    ax1.plot(td, offset_cos(phi, *extras['popt_A']))

    # Phase change is shown in millicycles (x1e3).
    ax2.plot(td, control['dphi_tp_end [cyc]'] * 1e3, 'b.')
    ax2.plot(td, offset_cos(phi, *extras['popt_phi']) * 1e3)

    ax1.grid()
    ax2.grid()
    ax1.set_xlabel(r'$\tau_\mathrm{d} \; [\mu\mathrm{s}]$')
    ax1.set_ylabel(r'$\Delta A \; [\mathrm{nm}]$')
    ax2.set_ylabel(r'$\Delta \phi \; [\mathrm{mcyc.}]$')

    if filename is not None:
        fig.savefig(filename, bbox_inches='tight')
    return fig, (ax1, ax2)
Beispiel #22
0
def test_embed_limit(method_name, caplog):
    """Exceeding ``animation.embed_limit`` must log exactly one warning."""
    # A limit of 1e-6 (~1 byte) is guaranteed to be exceeded.
    with mpl.rc_context({"animation.embed_limit": 1e-6}):
        animation = make_animation(frames=1)
        getattr(animation, method_name)()

    records = caplog.records
    assert len(records) == 1
    record, = records
    assert record.name == "matplotlib.animation"
    assert record.levelname == "WARNING"
Beispiel #23
0
 def test_plot(self, instance):
     """``instance.plot`` must return a savable SegmentPlot on SegmentAxes."""
     with rc_context(rc={'text.usetex': False}):
         plot = instance.plot(figsize=(6.4, 3.8))
         assert isinstance(plot, SegmentPlot)
         current_axes = plot.gca()
         assert isinstance(current_axes, SegmentAxes)
         with tempfile.NamedTemporaryFile(suffix='.png') as handle:
             plot.save(handle.name)
Beispiel #24
0
def test_colorbar_closed_patch():
    """Colorbar outlines, including extensions, must be closed paths.

    A very wide ``axes.linewidth`` makes an incorrectly closed path visible
    (see PR #4186).  Colorbars are drawn with both-sided extensions
    (default, ``extendfrac=0.5``, rectangular) and with none.
    """
    fig = plt.figure(figsize=(8, 6))
    ax1 = fig.add_axes([0.05, 0.85, 0.9, 0.1])
    ax2 = fig.add_axes([0.1, 0.65, 0.75, 0.1])
    ax3 = fig.add_axes([0.05, 0.45, 0.9, 0.1])
    ax4 = fig.add_axes([0.05, 0.25, 0.9, 0.1])
    ax5 = fig.add_axes([0.05, 0.05, 0.9, 0.1])

    # 5-entry RdBu colormap.  plt.get_cmap is used because
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    cmap = plt.get_cmap("RdBu", lut=5)

    im = ax1.pcolormesh(np.linspace(0, 10, 16).reshape((4, 4)), cmap=cmap)

    # The use of a "values" kwarg here is unusual.  It works only
    # because it is matched to the data range in the image and to
    # the number of colors in the LUT.
    values = np.linspace(0, 10, 5)
    cbar_kw = dict(cmap=cmap, orientation='horizontal', values=values,
                   ticks=[])

    # The wide line is to show that the closed path is being handled
    # correctly.  See PR #4186.
    with rc_context({'axes.linewidth': 16}):
        plt.colorbar(im, cax=ax2, extend='both', extendfrac=0.5, **cbar_kw)
        plt.colorbar(im, cax=ax3, extend='both', **cbar_kw)
        plt.colorbar(im, cax=ax4, extend='both', extendrect=True, **cbar_kw)
        plt.colorbar(im, cax=ax5, extend='neither', **cbar_kw)
def plot_weight_comparisons(gd_file, mal_file,
                            malicious_behaviour="Selfish", s="bella_static",
                            excluded=None, show_title=True, figsize=None,
                            labels=None, prefix="img/", extension="pdf"):
    """Plot fair-vs-malicious MTFM comparisons for scenario *s*.

    The first plot uses the default weighting.  Then one plot is drawn per
    metric in the module-level ``trust_metrics`` that is not in *excluded*,
    each weighted by ``weight_for_metric(metric, 3)``.  All plots are drawn
    with ``text.usetex`` enabled.

    Parameters
    ----------
    gd_file, mal_file
        Inputs for the good and malicious runs, passed to
        ``per_scenario_gd_mal_trusts``.
    malicious_behaviour : str
        Passed as ``keyword`` to ``plot_comparison``.
    s : str
        Scenario name passed to the perspective and plotting helpers.
    excluded : iterable, optional
        Metrics to skip in the per-metric plots.
    show_title, figsize, labels, prefix, extension
        Forwarded to ``plot_comparison``.  *labels* defaults to
        ``["Fair", "Selfish"]``.
    """
    if labels is None:
        labels = ["Fair", "Selfish"]

    if excluded is None:
        excluded = []

    mtfm_args = ('n0', 'n1', ['n2', 'n3'], ['n4', 'n5'])

    def _summed_mtfm(trust_perspective):
        # MTFM over the fixed node groups above, summed across columns.
        return Trust.generate_mtfm(trust_perspective, *mtfm_args).sum(axis=1)

    # Options shared by every plot_comparison call.
    plot_kwargs = dict(show_title=show_title, keyword=malicious_behaviour,
                       figsize=figsize, labels=labels, prefix=prefix,
                       extension=extension, show_grid=False)

    with mpl.rc_context(rc={'text.usetex': True}):
        gd_trust, mal_trust = per_scenario_gd_mal_trusts(gd_file, mal_file)

        # Baseline comparison with the default weighting.
        gd_tp, mal_tp = per_scenario_gd_mal_trust_perspective(gd_trust, mal_trust, s=s)
        plot_comparison(_summed_mtfm(gd_tp), _summed_mtfm(mal_tp), s=s,
                        **plot_kwargs)

        for metric in trust_metrics:
            if metric in excluded:
                continue
            gd_tp, mal_tp = per_scenario_gd_mal_trust_perspective(
                gd_trust, mal_trust, s=s,
                weight_vector=weight_for_metric(metric, 3))
            # ``s`` is passed by keyword, as in the baseline call above; the
            # original passed it positionally here, relying on it being
            # plot_comparison's third parameter.
            plot_comparison(_summed_mtfm(gd_tp), _summed_mtfm(mal_tp), s=s,
                            metric=metric, **plot_kwargs)
Beispiel #26
0
    def __init__(self, layout, axis=None, create_axes=True, ranges=None,
                 layout_num=1, keys=None, **params):
        """Set up a grid plot for a GridSpace *layout*.

        If *axis* is given, the grid is fitted into that axis's bounding box.
        Subplots are created under ``self.fig_rcparams``.  For a top-level
        plot, a comm is created and propagated to all subplots.

        Raises
        ------
        Exception
            If *layout* is not a ``GridSpace``.
        """
        if not isinstance(layout, GridSpace):
            raise Exception("GridPlot only accepts GridSpace.")
        super(GridPlot, self).__init__(layout, layout_num=layout_num,
                                       ranges=ranges, keys=keys, **params)
        # When embedding into an existing axis, confine the GridSpec to that
        # axis's bounding box (figure coordinates).
        grid_kwargs = {}
        if axis is not None:
            bbox = axis.get_position()
            l, b, w, h = bbox.x0, bbox.y0, bbox.width, bbox.height
            grid_kwargs = {'left': l, 'right': l+w, 'bottom': b, 'top': b+h}
            self.position = (l, b, w, h)

        # layout.shape is unpacked as (cols, rows); GridSpec wants rows first.
        self.cols, self.rows = layout.shape
        self.fig_inches = self._get_size()
        self._layoutspec = gridspec.GridSpec(self.rows, self.cols, **grid_kwargs)

        # Figure-level rc settings apply only while the subplots are created.
        with mpl.rc_context(rc=self.fig_rcparams):
            self.subplots, self.subaxes, self.layout = self._create_subplots(layout, axis,
                                                                             ranges, create_axes)
        if self.top_level:
            # Share one comm with every plot in the tree and attach streams
            # to the element plots.
            self.comm = self.init_comm()
            self.traverse(lambda x: setattr(x, 'comm', self.comm))
            self.traverse(lambda x: attach_streams(self, x.hmap, 2),
                          [GenericElementPlot])
def test_font_priority():
    """The first available family in ``font.sans-serif`` must be chosen."""
    preferred_fonts = ['cmmi10', 'Bitstream Vera Sans']
    with rc_context(rc={'font.sans-serif': preferred_fonts}):
        properties = FontProperties(family=["sans-serif"])
        font = findfont(properties)
    assert_equal(os.path.basename(font), 'cmmi10.ttf')
Beispiel #28
0
def test_colorbar_get_ticks():
    """With non-classic defaults, a colorbar for [0, 1) data ticks at 0.2..0.8."""
    with rc_context({'_internal.classic_mode': False}):
        fig, ax = plt.subplots()
        # Fixed seed for reproducible data; np.random.rand draws from [0, 1).
        np.random.seed(19680801)
        mesh = ax.pcolormesh(np.random.rand(30, 30))
        colorbar = fig.colorbar(mesh)
        expected_ticks = [0.2, 0.4, 0.6, 0.8]
        np.testing.assert_allclose(colorbar.get_ticks(), expected_ticks)
Beispiel #29
0
 def test_plot(self, array):
     """Plotting *array* must draw its x-index and values and save as PNG."""
     with rc_context(rc={'text.usetex': False}):
         plot = array.plot()
         first_line = plot.gca().lines[0]
         utils.assert_array_equal(first_line.get_xdata(), array.xindex.value)
         utils.assert_array_equal(first_line.get_ydata(), array.value)
         png_buffer = BytesIO()
         plot.save(png_buffer, format='png')
         plot.close()
Beispiel #30
0
def test_patheffects():
    """A global ``path.effects`` rcParam should not break PostScript output."""
    with matplotlib.rc_context():
        # White stroke of width 4 around every drawn path.
        stroke = patheffects.withStroke(linewidth=4, foreground='w')
        matplotlib.rcParams['path.effects'] = [stroke]
        fig, ax = plt.subplots()
        ax.plot([1, 2, 3])
        with io.BytesIO() as ps_buffer:
            fig.savefig(ps_buffer, format='ps')
Beispiel #31
0
def test_jpeg_alpha():
    """Saving RGBA data as JPEG must blend alpha onto the savefig facecolor."""
    plt.figure(figsize=(1, 1), dpi=300)
    # RGBA image: RGB is all zero (black); alpha ramps 0 -> 1 left to right.
    rgba = np.zeros((300, 300, 4), dtype=float)
    rgba[..., 3] = np.linspace(0.0, 1.0, 300)
    plt.figimage(rgba)

    buff = io.BytesIO()
    with rc_context({'savefig.facecolor': 'red'}):
        plt.savefig(buff, transparent=True, format='jpg', dpi=300)
    buff.seek(0)
    image = Image.open(buff)

    # If alpha were ignored the image would be a single colour (all black);
    # a working blend yields many distinct shades, hence a range check.
    assert 175 <= len(image.getcolors(256)) <= 185
    # The fully transparent corner should show the red facecolor.
    assert image.getpixel((0, 0)) == (254, 0, 0)
Beispiel #32
0
    def make(self):
        """Construct the plot by running the build steps in order.

        This is the 'main' for ggplot.  It closes the current pyplot figure,
        then (inside ``mpl.rc_context`` so rcParams changes made while
        building are reverted afterwards) applies the theme, creates the
        figure/subplots, draws every layer for every group and facet, and
        finishes limits, labels, scales, coords, legend and theme touches.
        """
        plt.close()
        with mpl.rc_context():
            # presumably sets rcParams; reverted when rc_context exits
            self.apply_theme()

            if self.facets:
                self.fig, self.subplots = self.make_facets()
            else:
                subplot_kw = {}
                if self.coords == "polar":
                    subplot_kw = {"polar": True}
                self.fig, self.subplots = plt.subplots(subplot_kw=subplot_kw)

            self.apply_scales()

            legend, groups = self._construct_plot_data()
            self._aes.legend = legend
            for _, group in groups:
                for ax, facetgroup in self.get_facet_groups(group):
                    for layer in self.layers:
                        kwargs = self._prep_layer_for_plotting(
                            layer, facetgroup)
                        # A False return appears to mean "skip this layer
                        # for this facet group" -- confirm against
                        # _prep_layer_for_plotting.
                        if kwargs == False:
                            continue
                        layer.plot(ax, facetgroup, self._aes, **kwargs)

            self.apply_limits()
            self.add_labels()
            self.apply_axis_scales()
            self.apply_axis_labels()
            self.apply_coords()
            self.add_legend(legend)

            if self.theme:
                for ax in self._iterate_subplots():
                    self.theme.apply_final_touches(ax)
Beispiel #33
0
    def plot_roc(self):
        """
        Plot the receiver operating characteristic (ROC) curve for this
        subject's classifier.

        Reads ``self.res['y']`` (labels), ``self.res['probs']`` (scores) and
        ``self.res['auc']``.  If ``self.res`` contains ``'p_val'``, it is
        shown in the title; a p-value of exactly 0 is reported as the bound
        ``< 1 / self.num_iters``.  The figure is resized to 12x9 inches.
        """
        fpr, tpr, _ = roc_curve(self.res['y'], self.res['probs'])
        auc = self.res['auc']
        with plt.style.context('fivethirtyeight'):
            with mpl.rc_context({
                    'ytick.labelsize': 16,
                    'xtick.labelsize': 16
            }):
                plt.plot(fpr,
                         tpr,
                         lw=4,
                         label='ROC curve (AUC = %0.2f)' % auc)
                # Chance diagonal, kept out of the legend.
                plt.plot([0, 1], [0, 1],
                         color='k',
                         lw=2,
                         linestyle='--',
                         label='_nolegend_')
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.05])
                plt.xlabel('False Positive Rate', fontsize=24)
                plt.ylabel('True Positive Rate', fontsize=24)
                plt.legend(loc="lower right")
                if 'p_val' not in self.res:
                    title = 'ROC (AUC: {0:.3f})'.format(auc)
                else:
                    p = self.res['p_val']
                    if p == 0:
                        # Only an upper bound can be stated (presumably the
                        # resolution of a num_iters-iteration permutation
                        # test).  1.0 forces true division even on Python 2,
                        # where 1 / num_iters would floor to 0.
                        p_str = '< {0:.2f}'.format(1.0 / self.num_iters)
                    else:
                        p_str = '= {0:.3f}'.format(p)
                    title = 'ROC (AUC: {0:.3f}, p{1})'.format(auc, p_str)
                plt.title(title)
        plt.gcf().set_size_inches(12, 9)
Beispiel #34
0
    def _draw(self, return_ggplot=False):
        """Build and render a copy of this ggplot into a matplotlib figure.

        Works on a deep copy of ``self`` so the caller's object is never
        modified (a default theme may be assigned to the copy).  Theming and
        drawing run inside ``mpl.rc_context``, so rcParams set by the theme
        are restored afterwards.

        Parameters
        ----------
        return_ggplot : bool
            If True, also return the built copy.

        Returns
        -------
        figure, or (figure, ggplot) when *return_ggplot* is True.

        Raises
        ------
        Exception
            Any error raised while drawing is re-raised unchanged, after the
            partially drawn figure (if one exists) has been closed.
        """
        # Prevent against any modifications to the user's ggplot object.
        # Copy here because we may assign a default theme below.
        self = deepcopy(self)
        self._build()

        # If no theme is set, use the default.
        self.theme = self.theme or theme_get()

        try:
            with mpl.rc_context():
                # setup & rcparams theming
                self.theme.apply_rcparams()
                self._create_figure()
                self._setup_parameters()
                self._resize_panels()
                # Drawing
                self._draw_layers()
                self._draw_labels()
                self._draw_breaks_and_labels()
                self._draw_legend()
                self._draw_title()
                self._draw_watermarks()
                # Artist object theming
                self._apply_theme()  # !!
        except Exception:
            # Don't leave a half-drawn figure open.  A bare ``raise``
            # re-raises the original exception with its traceback intact.
            if self.figure is not None:
                plt.close(self.figure)
            raise

        if return_ggplot:
            return self.figure, self
        return self.figure
Beispiel #35
0
def main() -> None:
    """Draw the accuracy-vs-ECE scatter grid for all datasets and save it.

    One panel is drawn per entry of ``DATASET_NAMES`` on a 2x3 grid of
    axes. Grid cells that receive no dataset are removed. The result is
    written to ``FIGURE_DIR + 'scatter.pdf'``.

    All decoration (y-labels, layout) and the final ``savefig`` happen
    inside the ``DEFAULT_RC`` context, so they use the same rcParams as
    the panels themselves. Text properties are taken from rcParams when
    a Text artist is created, and save-time options are read when saving.
    """
    with mpl.rc_context(rc=DEFAULT_RC):
        fig, axes = plt.subplots(ncols=3, nrows=2, dpi=300, sharey=False)
        n_cols = axes.shape[1]
        for idx, dataset in enumerate(DATASET_NAMES):
            row, col = divmod(idx, n_cols)
            datafile = DATAFILE_LIST[dataset]
            num_classes = NUM_CLASSES_DICT[dataset]

            # only the first three outputs of prepare_data are needed here
            categories, observations, confidences, _, _, _ = prepare_data(datafile, False)

            # accuracy models
            accuracy_model = BetaBernoulli(k=num_classes, prior=None)
            accuracy_model.update_batch(categories, observations)

            # ece models for each class
            ece_model = ClasswiseEce(num_classes, num_bins=10, pseudocount=2)
            ece_model.update_batch(categories, observations, confidences)

            # draw samples from posterior of classwise accuracy / ECE
            accuracy_samples = accuracy_model.sample(num_samples)  # (num_categories, num_samples)
            ece_samples = ece_model.sample(num_samples)  # (num_categories, num_samples)

            # plot_scatter returns the axis to use for this cell; store it
            # back so the titles/labels below go to the same object.
            axes[row, col] = plot_scatter(axes[row, col], accuracy_samples, ece_samples,
                                          limit=TOPK_DICT[dataset], plot_kwargs={})
            axes[row, col].set_title(DATASET_NAMES[dataset])

        # Still inside the rc context: labels created here and the saved
        # file must honour DEFAULT_RC as well.
        axes[0, 0].set_ylabel('ECE')
        axes[1, 0].set_ylabel('ECE')
        fig.set_size_inches(TEXT_WIDTH, 4.0)
        # NOTE(review): tight_layout() below recomputes the subplot params,
        # so this subplots_adjust call is likely overridden.
        fig.subplots_adjust(bottom=0.05, wspace=0.2)
        # remove the grid cells that did not receive a dataset
        for unused_ax in axes.flatten()[len(DATASET_NAMES):]:
            fig.delaxes(unused_ax)
        figname = FIGURE_DIR + 'scatter.pdf'
        fig.tight_layout()
        fig.savefig(figname, bbox_inches='tight', pad_inches=0)
Beispiel #36
0
    def update_depth_limits(self):
        """Recompute the depth (y) limits of ``self.speed_ax`` and redraw.

        When a profile is loaded and has processed depth samples, the
        reference depth is the larger of:

        - the deepest processed sample, and
        - the SIS mean XYZ depth, used only when SIS is enabled and
          provides it (otherwise 0).

        The axis is inverted (bottom = deeper). Whether or not the limits
        change, ``self.c.draw()`` is called inside the rc context.
        """
        with rc_context(self.rc_context):
            if self.lib.has_ssp() and len(self.lib.cur.proc.depth[self.vi]) > 0:
                deepest_sample = self.lib.cur.proc.depth[self.vi].max()

                sis_depth = 0
                if (self.lib.use_sis()
                        and self.lib.listeners.sis.xyz
                        and self.lib.listeners.sis.xyz_mean_depth):
                    sis_depth = self.lib.listeners.sis.xyz_mean_depth
                reference = max(deepest_sample, sis_depth)

                # A positive top limit only occurs for a negative reference
                # depth; use a fixed -5 in that case.
                top_limit = -0.05 * reference
                if top_limit > 0:
                    top_limit = -5
                bottom_limit = max(30. + reference, 1.1 * reference)

                self.speed_ax.set_ylim(bottom=bottom_limit, top=top_limit)

            self.c.draw()
Beispiel #37
0
    def __init__(self, layout, axis=None, create_axes=True, ranges=None,
                 layout_num=1, keys=None, **params):
        """Set up the subplot grid for a GridSpace.

        Parameters
        ----------
        layout : GridSpace
            The grid to plot; any other type is rejected.
        axis : matplotlib Axes, optional
            When given, the gridspec is confined to this axis' bounding
            box and ``self.position`` is set to
            ``(left, bottom, width, height)`` of that box.
        create_axes : bool
            Forwarded to ``_create_subplots``.
        ranges, layout_num, keys, **params
            Forwarded to the parent class constructor (``ranges`` is also
            forwarded to ``_create_subplots``).

        Raises
        ------
        TypeError
            If *layout* is not a GridSpace. TypeError subclasses
            Exception, so existing ``except Exception`` handlers still
            catch it.
        """
        if not isinstance(layout, GridSpace):
            raise TypeError("GridPlot only accepts GridSpace.")
        super(GridPlot, self).__init__(layout, layout_num=layout_num,
                                       ranges=ranges, keys=keys, **params)
        # Confine the gridspec to the bounding box of the given axis, if any
        grid_kwargs = {}
        if axis is not None:
            bbox = axis.get_position()
            left, bottom = bbox.x0, bbox.y0
            width, height = bbox.width, bbox.height
            grid_kwargs = {'left': left, 'right': left + width,
                           'bottom': bottom, 'top': bottom + height}
            self.position = (left, bottom, width, height)

        self.cols, self.rows = layout.shape
        self.fig_inches = self._get_size()
        self._layoutspec = gridspec.GridSpec(self.rows, self.cols, **grid_kwargs)

        # Subplots are created with the figure rcParams applied
        with mpl.rc_context(rc=self.fig_rcparams):
            self.subplots, self.subaxes, self.layout = self._create_subplots(layout, axis,
                                                                             ranges, create_axes)
        if self.top_level:
            # attach streams for every GenericElementPlot in the plot tree
            self.traverse(lambda x: attach_streams(self, x.hmap, 2),
                          [GenericElementPlot])
Beispiel #38
0
    def test_rcparams(self):
        """WCSAxes honours custom rcParams for axes, ticks and grid.

        Draws a small WCSAxes with a grid and axis labels under a set of
        non-default rcParams and returns the figure for image comparison.
        """
        custom_rc = {
            # axes
            'axes.labelcolor': 'purple',
            'axes.labelsize': 14,
            'axes.labelweight': 'bold',
            'axes.linewidth': 3,
            'axes.facecolor': '0.5',
            'axes.edgecolor': 'green',
            # x ticks
            'xtick.color': 'red',
            'xtick.labelsize': 8,
            'xtick.direction': 'in',
            'xtick.minor.visible': True,
            'xtick.minor.size': 5,
            'xtick.major.size': 20,
            'xtick.major.width': 3,
            'xtick.major.pad': 10,
            # grid
            'grid.color': 'blue',
            'grid.linestyle': ':',
            'grid.linewidth': 1,
            'grid.alpha': 0.5,
        }

        with rc_context(custom_rc):
            figure = plt.figure(figsize=(6, 6))
            wcs_ax = WCSAxes(figure, [0.15, 0.1, 0.7, 0.7], wcs=None)
            figure.add_axes(wcs_ax)
            wcs_ax.set_xlim(-0.5, 2)
            wcs_ax.set_ylim(-0.5, 2)
            wcs_ax.grid()
            wcs_ax.set_xlabel('X label')
            wcs_ax.set_ylabel('Y label')
            for coord_index in (0, 1):
                wcs_ax.coords[coord_index].set_ticklabel(exclude_overlapping=True)
            return figure
Beispiel #39
0
    def _draw_using_figure(self, figure, axs):
        """
        Draw onto already created figure and axes

        This can be used to draw animation frames,
        or inset plots. It is intended to be used
        after the key plot has been drawn.

        The plot is deep-copied first, so the caller's object is
        not modified; the drawn copy is what gets returned.

        Parameters
        ----------
        figure : matplotlib.figure.Figure
            Matplotlib figure
        axs : array_like
            Array of Axes onto which to draw the plots

        Returns
        -------
        out : same type as the original plot
            The deep copy that was drawn onto *figure*.

        Notes
        -----
        If any drawing step raises, *figure* (which was supplied by
        the caller) is closed with ``plt.close`` before the exception
        is re-raised.
        """
        # Work on a copy so building/theming doesn't touch the caller's object
        self = deepcopy(self)
        self._build()

        # Use theme_get() when no theme was set on the plot
        self.theme = self.theme or theme_get()
        self.figure = figure
        self.axs = axs

        try:
            # rcParams applied by the theme are restored when this exits
            with mpl.rc_context():
                self.theme.apply_rcparams()
                self._setup_parameters()
                self._draw_layers()
                self._draw_facet_labels()
                self._draw_legend()
                self._apply_theme()
        except Exception as err:
            if self.figure is not None:
                plt.close(self.figure)
            raise err

        return self
Beispiel #40
0
def test_colorbar_autotickslog():
    # Check the non-classic automatic tick locations on log-normed
    # colorbars: a full-height one and a shrunken one (shrink=0.4).
    with rc_context({'_internal.classic_mode': False}):
        fig, axes = plt.subplots(2, 1)
        X, Y = np.meshgrid(np.arange(-3.0, 4.001), np.arange(-4.0, 3.001))
        Z = X * Y

        colorbars = []
        for ax, extra_kw in zip(axes, ({}, {'shrink': 0.4})):
            mesh = ax.pcolormesh(X, Y, 10**Z, norm=LogNorm())
            colorbars.append(fig.colorbar(mesh,
                                          ax=ax,
                                          extend='both',
                                          orientation='vertical',
                                          **extra_kw))

        expected_locs = (10**np.arange(-12, 12.2, 4.),
                         10**np.arange(-12, 13., 12.))
        for cbar, expected in zip(colorbars, expected_locs):
            np.testing.assert_almost_equal(cbar.ax.yaxis.get_ticklocs(),
                                           expected)
Beispiel #41
0
    def save(self,
             path,
             format: Optional[str] = None,
             dpi: Optional[float] = 150):
        """Saves the figure to a file.

        The figure is laid out with ``tight_layout`` and saved inside an
        rc context built from ``self.fname`` and ``self.rc``.

        Parameters
        ----------
            path: path-like or file-like object
                Destination of the saved figure (passed to
                ``Figure.savefig``).

            format: ``str``, optional
                File format, e.g. ``'png'``, ``'pdf'``, ``'svg'``. If not
                provided, the output format is inferred from the extension of
                ``path``.

            dpi: ``float``, optional
                Resolution in dots per inch. Defaults to ``150``.

        """
        with mpl.rc_context(fname=self.fname, rc=self.rc):
            self.fig.tight_layout()
            # ``format`` was previously accepted but never forwarded.
            self.fig.savefig(path, format=format, dpi=dpi, bbox_inches="tight")
Beispiel #42
0
    def plot_values(self,
                    ids,
                    model,
                    ymax,
                    ymin,
                    dpi=300,
                    sub=None,
                    subagg=None,
                    cells=None,
                    pairwise=False,
                    colors=None,
                    prefix=None,
                    w=None,
                    filter=None,
                    legend=False):
        """Plot values in cluster

        For each cluster, saves a barplot and a time-course (UTSStat)
        plot, both with transparent background, to ``self._dst.vec``.
        Optionally saves a color legend once.

        Parameters
        ----------
        ids : sequence | dict | scalar <= 1
            IDs of the clusters that should be plotted. For ANOVA results, this
            should be an ``{effect_name: id_list}`` dict. Instead of a list of
            IDs a scalar can be provided to plot all clusters with p-values
            smaller than this.
        model : str
            Model defining cells which to plot separately.
        ymax : scalar
            Top of the y-axis.
        ymin : scalar
            Bottom of the y axis.
        dpi : int
            Figure DPI.
        sub : str
            Only use a subset of the data.
        subagg : str
           Index in ds: within index, collapse across other predictors.
        cells : sequence of cells in model
            Modify visible cells and their order. Only applies to the barplot.
            Does not affect filename.
        pairwise : bool
            Add pairwise tests to barplots.
        colors : dict
            Substitute colors (default are the colors provided at
            initialization).
        prefix : str
            Prefix to use for the image files (optional, can be used to
            distinguish different groups of images sharing the same color-bars).
        w : scalar
            UTS-stat plot width (default is ``2 * h``).
        filter : Filter
            Filter signal for display purposes (optional). Applied with
            ``filtfilt`` to the time course only.
        legend : bool
            Plot a color legend (saved once, from the first cluster).
        """
        if w is None:
            w = self.h * 2
        ds, model, modelname = self._get_data(model, sub, subagg)
        ids = self._ids(ids)
        if colors is None:
            colors = self.colors

        src = ds['srcm']
        n_cells = len(ds.eval(model).cells)
        # barplot width grows with the number of cells in the model
        w_bar = (n_cells * 2 + 4) * (self.h / 12)
        with mpl.rc_context(self.rc):
            for cid in ids:
                name = cname(cid)
                if prefix:
                    name = prefix + ' ' + name
                cluster = self._get_cluster(cid)
                # presumably: y_mean averages src over the cluster's non-zero
                # extent; y_tc averages over sources that are in the cluster
                # at any time point, keeping time (used for the time course)
                y_mean = src.mean(cluster != 0)
                y_tc = src.mean(cluster.any('time'))

                # barplot
                p = plot.Barplot(y_mean,
                                 model,
                                 'subject',
                                 None,
                                 cells,
                                 pairwise,
                                 ds=ds,
                                 trend=False,
                                 corr=None,
                                 title=None,
                                 frame=False,
                                 yaxis=False,
                                 ylabel=False,
                                 colors=colors,
                                 bottom=ymin,
                                 top=ymax,
                                 w=w_bar,
                                 h=self.h,
                                 xlabel=None,
                                 xticks=None,
                                 tight=False,
                                 test_markers=False,
                                 show=False)
                p.save(self._dst.vec % ' '.join((name, modelname, 'barplot')),
                       dpi=dpi,
                       transparent=True)
                p.close()

                # time-course
                if filter is not None:
                    y_tc = filter.filtfilt(y_tc)
                p = plot.UTSStat(y_tc,
                                 model,
                                 match='subject',
                                 ds=ds,
                                 error='sem',
                                 colors=colors,
                                 title=None,
                                 axtitle=None,
                                 frame=False,
                                 bottom=ymin,
                                 top=ymax,
                                 legend=None,
                                 ylabel=None,
                                 xlabel=None,
                                 w=w,
                                 h=self.h,
                                 tight=False,
                                 show=False)
                # shade the cluster's time window, shifted back by half a
                # time step (presumably to centre the span on the samples)
                dt = y_tc.time.tstep / 2.
                mark_start = cluster.info['tstart'] - dt
                mark_stop = cluster.info['tstop'] - dt
                p.add_vspan(mark_start,
                            mark_stop,
                            color='k',
                            alpha=0.1,
                            zorder=-2)
                p.save(self._dst.vec % ' '.join(
                    (name, modelname, 'timecourse')),
                       dpi=dpi,
                       transparent=True)
                p.close()

                # legend (only once)
                # NOTE(review): save_legend is called on the UTSStat plot after
                # p.close(); confirm this is supported by the plot class.
                if legend:
                    p.save_legend(self._dst.vec % (modelname + ' legend'),
                                  transparent=True)
                    legend = False
Beispiel #43
0
    def plot_clusters_spatial(self, ids, views, w=600, h=480, prefix=''):
        """Plot spatial extent of the clusters

        Parameters
        ----------
        ids : sequence | dict | scalar <= 1
            IDs of the clusters that should be plotted. For ANOVA results, this
            should be an ``{effect_name: id_list}`` dict. Instead of a list of
            IDs a scalar can be provided to plot all clusters with p-values
            smaller than this.
        views : str | list of str | dict
            Can be a str or list of str to use the same views for all
            clusters. NOTE(review): a dict is currently iterated like a list,
            i.e. its keys are used as view names for every cluster;
            per-label/per-cluster views are not implemented here.
        w, h : int
            Size in pixels. The default (600 x 480) corresponds to 2 x 1.6 in
            at 300 dpi.
        prefix : str
            Prefix to use for the image files (optional, can be used to
            distinguish different groups of images sharing the same color-bars).

        Notes
        -----
        The horizontal colorbar is 1.5 in wide, the vertical colorbar is 1.6 in
        high. Colorbars are saved only once, for the first hemisphere drawn.
        """
        ids = self._ids(ids)
        clusters = self._get_clusters(ids)
        clusters_spatial = [c.sum('time') for c in clusters]
        if isinstance(views, str):
            views = (views, )

        # color limits shared by all clusters
        vmin = min(c.min() for c in clusters_spatial)
        vmax = max(c.max() for c in clusters_spatial)
        abs_vmax = max(vmax, abs(vmin))

        # anatomical extent
        brain_colorbar_done = False
        for cid, cluster in zip(ids, clusters_spatial):
            name = cname(cid)
            if prefix:
                name = prefix + ' ' + name

            for hemi in ('lh', 'rh'):
                if not cluster.sub(source=hemi).any():
                    continue
                brain = plot.brain.cluster(cluster,
                                           abs_vmax,
                                           views='lat',
                                           background=(1, 1, 1),
                                           colorbar=False,
                                           parallel=True,
                                           hemi=hemi,
                                           w=w,
                                           h=h)
                # make sure the brain window is closed even if saving fails
                try:
                    for view in views:
                        brain.show_view(view)
                        # NOTE(review): images go to ``self._dst_pix`` while the
                        # other outputs use ``self._dst.vec``; confirm the name.
                        brain.save_image(
                            self._dst_pix % ' '.join((name, hemi, view)), 'rgba',
                            True)

                    if not brain_colorbar_done:
                        with mpl.rc_context(self.rc):
                            label = "Sum of %s-values" % cluster.info['meas']
                            clipmin = 0 if vmin == 0 else None
                            clipmax = 0 if vmax == 0 else None
                            if prefix:
                                cbar_name = '%s cbar %%s' % prefix
                            else:
                                cbar_name = 'cbar %s'

                            h_cmap = 0.7 + POINT * mpl.rcParams['font.size']
                            p = brain.plot_colorbar(label,
                                                    clipmin=clipmin,
                                                    clipmax=clipmax,
                                                    width=0.1,
                                                    h=h_cmap,
                                                    w=1.5,
                                                    show=False)
                            p.save(self._dst.vec % cbar_name % 'h',
                                   transparent=True)
                            p.close()

                            # The vertical colorbar gets wider with the order of
                            # magnitude of the largest value. ``vmax`` can be 0
                            # or negative (all-negative clusters), where log10
                            # is undefined, so fall back to the colormap limit.
                            cbar_max = vmax if vmax > 0 else abs_vmax
                            if cbar_max > 0:
                                w_cmap = 0.8 + 0.1 * abs(floor(log10(cbar_max)))
                            else:
                                w_cmap = 0.8
                            p = brain.plot_colorbar(label,
                                                    clipmin=clipmin,
                                                    clipmax=clipmax,
                                                    width=0.1,
                                                    h=1.6,
                                                    w=w_cmap,
                                                    orientation='vertical',
                                                    show=False)
                            p.save(self._dst.vec % cbar_name % 'v',
                                   transparent=True)
                            p.close()

                            brain_colorbar_done = True
                finally:
                    brain.close()
Beispiel #44
0
def plot_classification_categorical(X,
                                    target_col,
                                    types=None,
                                    kind='auto',
                                    hue_order=None,
                                    **kwargs):
    """Plots for categorical features in classification.

    Creates plots of categorical variable distributions for each target class.
    Relevant features are identified via mutual information.

    For high cardinality categorical variables (variables with many categories)
    only the most frequent categories are shown.

    Parameters
    ----------
    X : dataframe
        Input data including features and target
    target_col : str or int
        Identifier of the target column in X
    types : dataframe of types, optional.
        Output of detect_types on X. Can be used to avoid recomputing the
        types.
    kind : string, default 'auto'
        Kind of plot to show. Options are 'count', 'proportion',
        'mosaic' and 'auto'.
        Count shows raw class counts within categories
        (can be hard to read with imbalanced classes)
        Proportion shows class proportions within categories
        (can be misleading with imbalanced categories)
        Mosaic shows both aspects, but can be a bit busy.
        Auto uses mosaic plots when the target has at most 5 classes
        and counts otherwise.
    hue_order : list, optional
        Order of the target classes; passed to ``seaborn.countplot``
        (only used when ``kind='count'``).
    **kwargs
        Currently unused.

    Returns
    -------
    None
        Nothing is returned; the plots are drawn on a new figure. Nothing is
        drawn if there are no categorical features besides the target.

    Raises
    ------
    ValueError
        If *kind* is not one of the options above.
    """
    types = _check_X_target_col(X, target_col, types, task="classification")
    if kind == "auto":
        kind = 'count' if X[target_col].nunique() > 5 else 'mosaic'

    features = X.loc[:, types.categorical]
    if target_col in features.columns:
        features = features.drop(target_col, axis=1)

    if features.shape[1] == 0:
        return

    features = features.astype('category')

    show_top = _get_n_top(features, "categorical")

    # can't use OrdinalEncoder because we might have mix of int and string
    ordinal_encoded = features.apply(lambda x: x.cat.codes)
    target = X[target_col]
    # The discrete mask needs one entry per column that is actually passed
    # in (the encoded categorical features), not one per column of X.
    f = mutual_info_classif(ordinal_encoded,
                            target,
                            discrete_features=np.ones(ordinal_encoded.shape[1],
                                                      dtype=bool))
    top_k = np.argsort(f)[-show_top:][::-1]
    # large number of categories -> taller plot
    row_height = 3 if features.nunique().max() <= 5 else 5
    fig, axes = _make_subplots(n_plots=show_top, row_height=row_height)
    plt.suptitle("Categorical Features vs Target", y=1.02)
    n_drawn = 0
    for i, (col_ind, ax) in enumerate(zip(top_k, axes.ravel())):
        col = features.columns[col_ind]
        if kind == 'proportion':
            X_new = _prune_category_make_X(X, col, target_col)

            # Sort by the share of one (arbitrary) class; positional access
            # so this works whatever the index of X looks like.
            df = (X_new.groupby(col)[target_col].value_counts(
                normalize=True).unstack().sort_values(by=target.iloc[0])
                  )  # hacky way to get a class name
            df.plot(kind='barh', stacked=True, ax=ax, legend=i == 0)
            ax.set_title(col)
            ax.set_ylabel(None)
        elif kind == 'mosaic':
            # how many categories make up at least 1% of data:
            n_cats = (X[col].value_counts() / len(X) > 0.01).sum()
            n_cats = np.minimum(n_cats, 20)
            X_new = _prune_category_make_X(X,
                                           col,
                                           target_col,
                                           max_categories=n_cats)
            mosaic_plot(X_new, col, target_col, ax=ax)
            ax.set_title(col)
        elif kind == 'count':
            X_new = _prune_category_make_X(X, col, target_col)

            # absolute counts
            # FIXME show f value
            # FIXME shorten titles?
            props = {}
            if X[target_col].nunique() > 15:
                props['font.size'] = 6
            with mpl.rc_context(props):
                sns.countplot(y=col,
                              data=X_new,
                              ax=ax,
                              hue=target_col,
                              hue_order=hue_order)
            # only the first panel keeps its legend
            if i > 0:
                ax.legend(())
        else:
            raise ValueError("Unknown plot kind {}".format(kind))
        _short_tick_names(ax)
        n_drawn = i + 1

    # turn off axes we didn't fill in the last row
    for ax in axes.ravel()[n_drawn:]:
        ax.set_axis_off()
Beispiel #45
0
    def map_profiles(self, pks=None, output_folder=None, save_fig=False):
        """Plot all the ssp in the database on a world map.

        Parameters
        ----------
        pks : container, optional
            If given, only rows whose key (``row[0]``) is in *pks* are
            plotted and the view is zoomed on them (5 degree padding,
            clamped to -180..180 / -90..90). No zoom is applied if the
            filter matches nothing.
        output_folder : str, optional
            Passed to ``self.plots_folder`` to locate the output file.
        save_fig : bool
            If true and *output_folder* is given, save ``ssp_map.png``.
            If false, pyplot interactive mode is switched on before drawing.

        Returns
        -------
        bool
            Always True.

        Raises
        ------
        RuntimeError
            If the database returns no rows.
        """
        with rc_context(self.rc_context):

            if not save_fig:
                plt.ion()

            rows = self.db.list_profiles()
            if rows is None or len(rows) == 0:
                raise RuntimeError(
                    "Unable to retrieve ssp view rows > Empty database?")

            # prepare the data: positions (row[2]) of the selected profiles
            selected = [row for row in rows if pks is None or row[0] in pks]
            ssp_x = [row[2].x for row in selected]
            ssp_y = [row[2].y for row in selected]

            # make the world map
            plt.close("Profiles Map")
            plt.figure("Profiles Map")
            ax = plt.subplot(111, aspect='equal')
            plt.subplots_adjust(left=0,
                                bottom=0,
                                right=1,
                                top=1,
                                wspace=0,
                                hspace=0)
            plt.ioff()

            # NOTE(review): ``rows`` is always non-empty here (checked above),
            # so the ``_world_draw_map`` branch is unreachable; confirm the
            # intended condition.
            if rows:
                wm = self._world_draw_polygons()
            else:
                wm = self._world_draw_map()
            x, y = wm(ssp_x, ssp_y)
            wm.scatter(x, y, marker='o', s=16, color='r')
            wm.scatter(x, y, marker='.', s=1, color='k')

            # Zoom on the selected profiles; min()/max() would fail on an
            # empty selection, so skip the zoom in that case.
            if pks is not None and len(x) > 0:
                delta = 5.0
                ax.set_ylim(max(min(y) - delta, -90.0),
                            min(max(y) + delta, 90.0))
                ax.set_xlim(max(min(x) - delta, -180.0),
                            min(max(x) + delta, 180.0))

            if save_fig and (output_folder is not None):
                plt.savefig(os.path.join(self.plots_folder(output_folder),
                                         'ssp_map.png'),
                            bbox_inches='tight')

        return True
Beispiel #46
0
 def test_basic(self, xmax, decimals, symbol,
                x, display_range, expected):
     """PercentFormatter.format_pct (with usetex off) yields *expected*."""
     pct_fmt = mticker.PercentFormatter(xmax, decimals, symbol)
     with matplotlib.rc_context(rc={'text.usetex': False}):
         rendered = pct_fmt.format_pct(x, display_range)
     assert rendered == expected
Beispiel #47
0
 def test_min_exponent(self, min_exponent, value, expected):
     """Formatting *value* under a given min_exponent rcParam gives *expected*."""
     rc_overrides = {'axes.formatter.min_exponent': min_exponent}
     with matplotlib.rc_context(rc_overrides):
         rendered = self.fmt(value)
     assert rendered == expected
Beispiel #48
0
def test_auto_date_locator_intmult_tz():
    """Check AutoDateLocator(interval_multiples=True) ticks in a time zone.

    For each span starting at 1997-01-01 local time, the tick datetimes
    (as ``str``) must equal the expected list; the expected values include
    the change of UTC offset between -08:00 and -07:00.
    """
    def _create_auto_date_locator(date1, date2, tz):
        # Locator on a dummy axis whose view interval is date1..date2
        locator = mdates.AutoDateLocator(interval_multiples=True, tz=tz)
        locator.create_dummy_axis()
        locator.set_view_interval(mdates.date2num(date1),
                                  mdates.date2num(date2))
        return locator

    # Each entry: [span added to the start date, expected tick strings]
    results = ([
        datetime.timedelta(weeks=52 * 200),
        [
            '1980-01-01 00:00:00-08:00', '2000-01-01 00:00:00-08:00',
            '2020-01-01 00:00:00-08:00', '2040-01-01 00:00:00-08:00',
            '2060-01-01 00:00:00-08:00', '2080-01-01 00:00:00-08:00',
            '2100-01-01 00:00:00-08:00', '2120-01-01 00:00:00-08:00',
            '2140-01-01 00:00:00-08:00', '2160-01-01 00:00:00-08:00',
            '2180-01-01 00:00:00-08:00', '2200-01-01 00:00:00-08:00'
        ]
    ], [
        datetime.timedelta(weeks=52),
        [
            '1997-01-01 00:00:00-08:00', '1997-02-01 00:00:00-08:00',
            '1997-03-01 00:00:00-08:00', '1997-04-01 00:00:00-08:00',
            '1997-05-01 00:00:00-07:00', '1997-06-01 00:00:00-07:00',
            '1997-07-01 00:00:00-07:00', '1997-08-01 00:00:00-07:00',
            '1997-09-01 00:00:00-07:00', '1997-10-01 00:00:00-07:00',
            '1997-11-01 00:00:00-08:00', '1997-12-01 00:00:00-08:00'
        ]
    ], [
        datetime.timedelta(days=141),
        [
            '1997-01-01 00:00:00-08:00', '1997-01-15 00:00:00-08:00',
            '1997-02-01 00:00:00-08:00', '1997-02-15 00:00:00-08:00',
            '1997-03-01 00:00:00-08:00', '1997-03-15 00:00:00-08:00',
            '1997-04-01 00:00:00-08:00', '1997-04-15 00:00:00-07:00',
            '1997-05-01 00:00:00-07:00', '1997-05-15 00:00:00-07:00'
        ]
    ], [
        datetime.timedelta(days=40),
        [
            '1997-01-01 00:00:00-08:00', '1997-01-05 00:00:00-08:00',
            '1997-01-09 00:00:00-08:00', '1997-01-13 00:00:00-08:00',
            '1997-01-17 00:00:00-08:00', '1997-01-21 00:00:00-08:00',
            '1997-01-25 00:00:00-08:00', '1997-01-29 00:00:00-08:00',
            '1997-02-01 00:00:00-08:00', '1997-02-05 00:00:00-08:00',
            '1997-02-09 00:00:00-08:00'
        ]
    ], [
        datetime.timedelta(hours=40),
        [
            '1997-01-01 00:00:00-08:00', '1997-01-01 04:00:00-08:00',
            '1997-01-01 08:00:00-08:00', '1997-01-01 12:00:00-08:00',
            '1997-01-01 16:00:00-08:00', '1997-01-01 20:00:00-08:00',
            '1997-01-02 00:00:00-08:00', '1997-01-02 04:00:00-08:00',
            '1997-01-02 08:00:00-08:00', '1997-01-02 12:00:00-08:00',
            '1997-01-02 16:00:00-08:00'
        ]
    ], [
        datetime.timedelta(minutes=20),
        [
            '1997-01-01 00:00:00-08:00', '1997-01-01 00:05:00-08:00',
            '1997-01-01 00:10:00-08:00', '1997-01-01 00:15:00-08:00',
            '1997-01-01 00:20:00-08:00'
        ]
    ], [
        datetime.timedelta(seconds=40),
        [
            '1997-01-01 00:00:00-08:00', '1997-01-01 00:00:05-08:00',
            '1997-01-01 00:00:10-08:00', '1997-01-01 00:00:15-08:00',
            '1997-01-01 00:00:20-08:00', '1997-01-01 00:00:25-08:00',
            '1997-01-01 00:00:30-08:00', '1997-01-01 00:00:35-08:00',
            '1997-01-01 00:00:40-08:00'
        ]
    ])

    # NOTE(review): 'Canada/Pacific' is a legacy tz-database alias (of
    # America/Vancouver); dateutil's gettz returns None if the alias is not
    # installed, which would silently run this test without a time zone.
    tz = dateutil.tz.gettz('Canada/Pacific')
    d1 = datetime.datetime(1997, 1, 1, tzinfo=tz)
    for t_delta, expected in results:
        # presumably selects the current (non-classic) locator defaults
        with rc_context({'_internal.classic_mode': False}):
            d2 = d1 + t_delta
            locator = _create_auto_date_locator(d1, d2, tz)
            st = list(map(str, mdates.num2date(locator(), tz=tz)))
            assert st == expected
Beispiel #49
0
def process_cluster_overlap_stats(
    cluster_overlap_stats,
    max_pvalues,
    plot_args,
    barcode_plot_png_by_maxpvalue,
    cluster_overlap_stats_out_fp,
    plotting_context,
    test_arg_dict=None,
    sample=None,
    cores=1,
    plots_only=False,
    additional_formats=None,
):
    """Test cluster/feature overlaps, save the results and draw barcode plots.

    Parameters
    ----------
    cluster_overlap_stats : rsp.ClusterOverlapStats
        Statistics object. Unless *plots_only*, its ``hits`` may be replaced
        by a column subsample, its ``test_per_feature`` and
        ``test_per_cluster_per_feature`` methods are called, and it is
        pickled to *cluster_overlap_stats_out_fp*.
    max_pvalues : iterable of float
        One barcode plot is drawn per threshold (after filtering
        ``cluster_pvalues`` by it).
    plot_args : dict
        Keyword arguments for ``rsp.barcode_heatmap``.
    barcode_plot_png_by_maxpvalue : str
        Format string with a ``{max_pvalue}`` field giving the PNG path.
    cluster_overlap_stats_out_fp : str
        Pickle path for the statistics object. Sibling output paths are
        derived by dropping its last two characters (NOTE(review): assumes a
        ``.p`` suffix).
    plotting_context : dict
        rcParams applied while drawing and saving the plots.
    test_arg_dict : dict, optional
        Overrides for the default arguments of ``test_per_feature``.
    sample : int, optional
        If smaller than the number of ``hits`` columns, randomly keep this
        many columns (without replacement).
    cores : int
        Passed to ``test_per_feature``.
    plots_only : bool
        Skip testing and saving; only draw plots.
    additional_formats : iterable of str, optional
        May contain ``"pdf"`` and/or ``"svg"`` to save extra copies of each plot.
    """
    if additional_formats is None:
        additional_formats = tuple()

    if not plots_only:
        if sample is not None and sample < cluster_overlap_stats.hits.shape[1]:
            random_sel = np.random.choice(cluster_overlap_stats.hits.shape[1],
                                          sample,
                                          replace=False)
            cluster_overlap_stats.hits = cluster_overlap_stats.hits.iloc[:,
                                                                         random_sel]

        # Calculate test per feature
        final_test_args = dict(simulate_pval=True,
                               replicate=int(1e4),
                               workspace=1_000_000)
        if test_arg_dict is not None:
            final_test_args.update(test_arg_dict)
        cluster_overlap_stats.test_per_feature(method="hybrid",
                                               cores=cores,
                                               test_args=final_test_args)
        # Calculate test per cluster per feature
        cluster_overlap_stats.test_per_cluster_per_feature()

        with open(cluster_overlap_stats_out_fp, "wb") as fout:
            pickle.dump(cluster_overlap_stats, fout)

        out_stem = cluster_overlap_stats_out_fp[:-2]
        cluster_overlap_stats.hits.to_pickle(out_stem + "_hits.p")
        cluster_overlap_stats.hits.to_csv(out_stem + "_hits.tsv",
                                          sep="\t",
                                          index=False)

        cluster_overlap_stats.cluster_pvalues.to_csv(
            out_stem + "_element-pvalues.tsv",
            sep="\t",
            index=False,
        )

        cluster_overlap_stats.log_odds_ratio.to_csv(
            out_stem + "_log-odds.tsv",
            sep="\t",
            index=False)

    cluster_overlap_stats: rsp.ClusterOverlapStats
    for max_pvalue in max_pvalues:
        # Create barcode figure for this max_pvalue
        cluster_overlap_stats_filtered = cluster_overlap_stats.filter(
            "cluster_pvalues", max_pvalue)
        if cluster_overlap_stats_filtered.cluster_pvalues.empty:
            print(
                f"WARNING: not features left for pvalue {max_pvalue} threshold. No plot created"
            )
            continue
        with mpl.rc_context(plotting_context):
            fig = rsp.barcode_heatmap(
                cluster_overlap_stats_filtered,
                **plot_args,
            )
            try:
                out_png = barcode_plot_png_by_maxpvalue.format(
                    max_pvalue=max_pvalue)
                fig.set_dpi(90)
                fig.savefig(out_png)
                for ext in ("pdf", "svg"):
                    if ext in additional_formats:
                        fig.savefig(out_png.replace(".png", "." + ext))
            finally:
                # Close exactly this figure, also when saving fails;
                # plt.close() without argument closes whichever is current.
                plt.close(fig)
from __future__ import absolute_import, division, print_function

import copy

import matplotlib
from matplotlib import pyplot as plt
from matplotlib._pylab_helpers import Gcf

import pytest
try:
    # mock in python 3.3+
    from unittest import mock
except ImportError:
    import mock

# Import qt_compat while the 'backend' rcParam is temporarily set to Qt5Agg
# (presumably so qt_compat settles on a Qt5 binding -- TODO confirm).
# pytest.importorskip skips this whole test module if the import fails or
# the imported module's __version__ is below minversion='5'.
with matplotlib.rc_context(rc={'backend': 'Qt5Agg'}):
    qt_compat = pytest.importorskip('matplotlib.backends.qt_compat',
                                    minversion='5')
from matplotlib.backends.backend_qt5 import (
    MODIFIER_KEYS, SUPER, ALT, CTRL, SHIFT)  # noqa

QtCore = qt_compat.QtCore
# Each MODIFIER_KEYS entry is unpacked into three values; the first is
# discarded. Going by the variable names, the second is the Qt modifier flag
# and the third the Qt key code -- NOTE(review): confirm against backend_qt5.
_, ControlModifier, ControlKey = MODIFIER_KEYS[CTRL]
_, AltModifier, AltKey = MODIFIER_KEYS[ALT]
_, SuperModifier, SuperKey = MODIFIER_KEYS[SUPER]
_, ShiftModifier, ShiftKey = MODIFIER_KEYS[SHIFT]


@pytest.mark.backend('Qt5Agg')
def test_fig_close():
    # save the state of Gcf.figs
def demo(ax, rcparams, title):
    """Draw every style in the global ``sty_cycle`` as a line on *ax*.

    The axis decorations are hidden and *title* is set. Then, with
    *rcparams* applied through ``mpl.rc_context``, style number ``j`` is
    plotted as ``(x, y + j)`` using the global ``x`` and ``y`` arrays.
    """
    ax.axis('off')
    ax.set_title(title)
    with mpl.rc_context(rc=rcparams):
        # Offset each successive style vertically by its index.
        for offset, style_kwargs in enumerate(sty_cycle):
            ax.plot(x, y + offset, **style_kwargs)
Beispiel #52
0
 def test_using_all_default_major_steps(self):
     """Compare AutoLocator's default steps with the expected major steps.

     The expected values are the first element of each entry in
     ``self.majorstep_minordivisions``. The check runs with
     ``_internal.classic_mode`` switched off.
     """
     with matplotlib.rc_context({'_internal.classic_mode': False}):
         expected = [entry[0] for entry in self.majorstep_minordivisions]
         actual = mticker.AutoLocator()._steps
         np.testing.assert_allclose(expected, actual)
Beispiel #53
0
 def test_use_offset(self, use_offset):
     """ScalarFormatter should adopt the ``axes.formatter.useoffset`` rcParam."""
     overrides = {'axes.formatter.useoffset': use_offset}
     with matplotlib.rc_context(overrides):
         formatter = mticker.ScalarFormatter()
         assert use_offset == formatter.get_useOffset()
Beispiel #54
0
def test_get_color_cycle(cycler, result):
    """``get_color_cycle`` should return *result* under the given prop cycle."""
    overrides = {"axes.prop_cycle": cycler}
    with mpl.rc_context(rc=overrides):
        assert get_color_cycle() == result
Beispiel #55
0
 def test_basic(self, base, value, expected):
     """Format *value* with LogFormatterSciNotation(base) and compare."""
     fmt = mticker.LogFormatterSciNotation(base=base)
     # Use a fixed sublabel set so the output does not depend on locator state.
     fmt.sublabel = {1, 2, 5, 1.2}
     # Force plain (non-TeX) output for a stable string comparison.
     with matplotlib.rc_context({'text.usetex': False}):
         assert fmt(value) == expected
Beispiel #56
0
def test_use_offset():
    """ScalarFormatter should follow ``axes.formatter.useoffset`` both ways."""
    for flag in (True, False):
        overrides = {'axes.formatter.useoffset': flag}
        with matplotlib.rc_context(overrides):
            formatter = mticker.ScalarFormatter()
            assert flag == formatter.get_useOffset()
Beispiel #57
0
 def test_latex(self, is_latex, usetex, expected):
     """Check PercentFormatter.format_pct(50, 100) under the LaTeX settings."""
     fmt = mticker.PercentFormatter(symbol='\\{t}%', is_latex=is_latex)
     overrides = {'text.usetex': usetex}
     with matplotlib.rc_context(rc=overrides):
         assert fmt.format_pct(50, 100) == expected
import yt
from matplotlib.animation import FuncAnimation
from matplotlib import rc_context

# Load all plotfiles matching the glob. The result is indexed (ts[i]) and
# measured (len(ts)) below, i.e. used as a series of datasets.
ts = yt.load('GasSloshingLowRes/sloshing_low_res_hdf5_plt_cnt_*')

# Slice through the first dataset along the 'z' axis, showing 'density'.
# Fixed colorbar limits presumably keep the color scale constant across
# animation frames.
plot = yt.SlicePlot(ts[0], 'z', 'density')
plot.set_zlim('density', 8e-29, 3e-26)

# Matplotlib figure behind the 'density' panel; FuncAnimation drives it.
fig = plot.plots['density'].figure

# FuncAnimation calls this with an integer frame number; the number doubles
# as the index of the dataset in the time series to display.
def animate(i):
    """Point the slice plot at the ``i``-th dataset of ``ts``."""
    plot._switch_ds(ts[i])

# One frame per dataset: with an integer `frames`, FuncAnimation passes the
# indices 0 .. len(ts) - 1 to animate.
animation = FuncAnimation(fig, animate, frames=len(ts))

# Override matplotlib's defaults to get a nicer looking font.
# The override is active only while the movie is rendered and written.
# NOTE(review): writing .mp4 needs a movie writer (e.g. ffmpeg) installed.
with rc_context({'mathtext.fontset': 'stix'}):
    animation.save('animation.mp4')
 def test_plot(self, instance):
     """Plot *instance*, check the axes type, and render it to PNG in memory."""
     # Disable TeX so rendering does not need a LaTeX installation.
     with rc_context(rc={'text.usetex': False}):
         fig = instance.plot(figsize=(6.4, 3.8))
         assert isinstance(fig.gca(), SegmentAxes)
         buffer = BytesIO()
         fig.save(buffer, format='png')
         fig.close()
Beispiel #60
0
 def process(self, outputfile=None, close=True):
     with rc_context(rc=self.rcParams):
         return self.draw()