コード例 #1
0
def test_raster_and_psth():
    """Check a combined raster + PSTH figure.

    The first axes of the current figure must hold the PSTH line and the
    second axes the per-trial raster lines, trial ``k`` being drawn at y == k.
    """
    spikes, _ = utils.create_default_fake_spikes()
    visualizations.raster_and_psth(spikes, trial_length=5.0)
    all_axes = plt.findobj(plt.gcf(), plt.Axes)
    psth = plt.findobj(all_axes[0], plt.Line2D)[0]
    rasters = plt.findobj(all_axes[1], plt.Line2D)[:2]

    bin_width, trial_count = 0.01, 2
    expected_bins = spikes.size / bin_width / trial_count
    assert psth.get_xdata().size == expected_bins, \
            'visualizations.raster_and_psth did not use trial length correctly'
    assert psth.get_ydata().max() == 1 / bin_width, \
            'visualizations.raster_and_psth did not correctly compute max spike rate'

    split_msg = 'visualizations.raster_and_psth did not correctly split rasters into trials'
    # Number of spikes expected in each trial, listed in trial order.
    for trial, n_spikes in enumerate((5, 4)):
        line = rasters[trial]
        assert line.get_xdata().size == n_spikes, split_msg
        assert line.get_ydata().size == n_spikes, split_msg
        assert np.all(line.get_ydata() == float(trial)), \
                'visualizations.raster_and_psth did not correctly label rasters'
コード例 #2
0
def align_axes(axis_name='xyzc', axes=None):
    """Give the given axes a common range along each axis in *axis_name*.

    *axis_name* is any combination of 'x', 'y', 'z' and 'c' (colour limits);
    the default 'xyzc' aligns all of them.  For every letter, the limits of
    all participating objects are replaced by (smallest lower limit, largest
    upper limit).  Objects without a ``get_<letter>lim`` getter are searched
    for children that have one (e.g. images carry ``clim`` while their Axes
    does not).  If no axis handles are specified, all axes in the current
    figure are used; ``None`` entries in *axes* are skipped.
    """
    if axes is None:
        axes = plt.findobj(match=plt.Axes)

    for name in axis_name:
        prop = '%slim' % name
        all_lim = []
        all_axes = []
        for ax in axes:
            if ax is None:
                continue
            try:
                all_lim.append(plt.get(ax, prop))
                all_axes.append(ax)
            except AttributeError:
                # ax has no get_<prop>; fall back to its children that do.
                for childax in plt.get(ax, 'children'):
                    try:
                        all_lim.append(plt.get(childax, prop))
                        all_axes.append(childax)
                    except AttributeError:
                        # This child has no such limit either; skip it.
                        pass
        if all_lim:
            all_lim = np.asarray(all_lim)
            aligned_lim = (all_lim[:, 0].min(), all_lim[:, 1].max())
            plt.setp(all_axes, prop, aligned_lim)
コード例 #3
0
ファイル: plottools.py プロジェクト: jsherrah/BostonHackDay
def align_axes(axis_name='xyzc', axes=None):
    """Give the given axes a common range along each axis in *axis_name*.

    *axis_name* is any combination of 'x', 'y', 'z' and 'c' (colour limits);
    the default 'xyzc' aligns all of them.  For every letter, the limits of
    all participating objects are replaced by (smallest lower limit, largest
    upper limit).  Objects without a ``get_<letter>lim`` getter are searched
    for children that have one (e.g. images carry ``clim`` while their Axes
    does not).  If no axis handles are specified, all axes in the current
    figure are used; ``None`` entries in *axes* are skipped.
    """
    if axes is None:
        axes = plt.findobj(match=plt.Axes)

    for name in axis_name:
        prop = '%slim' % name
        all_lim = []
        all_axes = []
        for ax in axes:
            if ax is None:
                # Placeholder slots (e.g. unused subplot positions) are ignored
                # instead of crashing on plt.get(None, 'children').
                continue
            try:
                all_lim.append(plt.get(ax, prop))
                all_axes.append(ax)
            except AttributeError:
                # ax has no get_<prop>; fall back to its children that do.
                for childax in plt.get(ax, 'children'):
                    try:
                        all_lim.append(plt.get(childax, prop))
                        all_axes.append(childax)
                    except AttributeError:
                        # This child has no such limit either; skip it.
                        pass
        if all_lim:
            all_lim = np.asarray(all_lim)
            aligned_lim = (all_lim[:, 0].min(), all_lim[:, 1].max())
            plt.setp(all_axes, prop, aligned_lim)
コード例 #4
0
def test_plot_sta():
    """Exercise visualizations.plot_sta with 1-D, 2-D, 3-D and 4-D input.

    * a temporal filter alone is drawn as a line,
    * a spatial filter alone is drawn as an image,
    * a spatiotemporal STA is split into an image (axes 0) and a line (axes 1),
    * anything with more dimensions must raise ValueError.
    """
    temporal_filter, spatial_filter, sta = utils.create_default_fake_filter()
    time = np.arange(temporal_filter.size)

    # -- temporal component only --
    visualizations.plot_sta(time, temporal_filter)
    trace = plt.findobj(plt.gca(), plt.Line2D)[0]
    assert np.all(trace.get_xdata() == time), 'Time axis data is incorrect.'
    assert np.all(trace.get_ydata() == temporal_filter), 'Temporal filter data is incorrect.'
    plt.close(plt.gcf())

    # -- spatial component only --
    visualizations.plot_sta(time, spatial_filter)
    image = plt.findobj(plt.gca(), AxesImage)[0]
    expected = (spatial_filter - spatial_filter.mean()) / spatial_filter.var()
    drawn = image.get_array()
    assert np.allclose(drawn, expected), 'Spatial filter data is incorrect.'
    plt.close(plt.gcf())

    # -- full spatiotemporal STA --
    # plot_sta decomposes the STA via filtertools.decompose (unit-norm
    # singular vectors) and then `spatial` applies its own normalization, so
    # the absolute scale of the drawn data is hard to predict.  Only the
    # shapes are compared, after scaling both sides to a common range.
    _, axes = visualizations.plot_sta(time, sta)
    image = plt.findobj(axes[0], AxesImage)[0]
    expected = spatial_filter - spatial_filter.mean()
    expected /= expected.max()
    drawn = image.get_array()
    drawn /= drawn.max()
    assert np.allclose(drawn, expected), 'Spatial filter data is incorrect.'

    trace = plt.findobj(axes[1], plt.Line2D)[0]
    assert np.all(trace.get_xdata() == time), 'Time axis data is incorrect.'
    expected = temporal_filter - temporal_filter.min()
    expected /= expected.max()
    drawn = trace.get_ydata()
    drawn -= drawn.min()
    drawn /= drawn.max()
    assert np.allclose(expected, drawn), 'Temporal filter data is incorrect.'

    # -- unsupported dimensionality --
    with pytest.raises(ValueError):
        visualizations.plot_sta(None, np.random.randn(2, 2, 2, 2))
コード例 #5
0
def test_ellipse():
    """Fit an ellipse to a fake receptive field and check its geometry."""
    _, spatial_filter, sta = utils.create_default_fake_filter()
    _, ax = visualizations.ellipse(sta)
    fitted = plt.findobj(ax, Ellipse)[0]

    expected_center = np.array(spatial_filter.shape) / 2.0
    assert np.allclose(fitted.center, expected_center), \
            'visualizations.ellipse did not compute correct ellipse center'
    assert np.allclose((fitted.height, fitted.width), 2.827082246), \
            'visualizations.ellipse computed incorrect width and/or height'
コード例 #6
0
def test_raster():
    """Draw a spike raster and check the plotted times and labels."""
    spikes, labels = utils.create_default_fake_spikes()
    visualizations.raster(spikes, labels)
    raster_line = plt.findobj(plt.gca(), plt.Line2D)[0]
    assert np.all(raster_line.get_xdata() == spikes), 'Spike times do not match'
    assert np.all(raster_line.get_ydata() == labels), 'Spike labels do not match'

    # Spike and label arrays of different lengths must be rejected.
    short_spikes = np.array((0, 1))
    long_labels = np.array((0, 1, 2))
    with pytest.raises(AssertionError):
        visualizations.raster(short_spikes, long_labels)
コード例 #7
0
def test_spatial_filter():
    """Plot a spatial filter with and without an explicit ``maxval``."""
    _, spatial_filter, _ = utils.create_default_fake_filter()

    # Default call: the image shows the mean-subtracted filter.
    visualizations.spatial(spatial_filter)
    centered = spatial_filter - spatial_filter.mean()
    image = plt.findobj(plt.gca(), AxesImage)[0]
    assert np.all(image.get_array() == centered), 'Spatial filter data is incorrect.'
    plt.close(plt.gcf())

    # Explicit maxval: raw data is shown with symmetric colour limits.
    limit = np.abs(spatial_filter).max()
    visualizations.spatial(spatial_filter, maxval=limit)
    image = plt.findobj(plt.gca(), AxesImage)[0]
    assert np.all(image.get_array() == spatial_filter), \
            'Spatial filter data incorrect when passing explicit maxval'
    assert np.all(image.get_clim() == np.array((-limit, limit))), \
            'Spatial filter color limits not set correctly.'
    plt.close(plt.gcf())
コード例 #8
0
def test_psth():
    """Plot a PSTH and check its number of bins and its peak rate."""
    spikes, _ = utils.create_default_fake_spikes()
    visualizations.psth(spikes, trial_length=5.0)
    bins, rates = plt.findobj(plt.gca(), plt.Line2D)[0].get_data()

    bin_width, trial_count = 0.01, 2
    assert bins.size == spikes.size / bin_width / trial_count, \
            'visualizations.psth did not use trial length correctly'
    assert rates.max() == 1 / bin_width, \
            'visualizations.psth did not correctly compute max spike rate'
コード例 #9
0
def test_play_sta():
    """Render one frame of the STA movie and compare it with that STA slice."""
    sta = utils.create_default_fake_filter()[-1]
    sta -= sta.mean()
    frame = utils.get_default_movie_frame()
    movie = visualizations.play_sta(sta)
    movie._func(frame)  # draw the chosen frame directly

    shown = plt.findobj(plt.gcf(), AxesImage)[0].get_array()
    shown -= shown.mean()
    expected = sta[frame, ...]
    expected -= expected.mean()
    assert np.allclose(shown, expected), \
            'visualizations.play_sta did not animate the 3D sta correctly.'
コード例 #10
0
def test_plot_cells():
    """Draw ellipses for several cells on one axes and check each of them."""
    cell_count = 2
    stas = [utils.create_default_fake_filter()[-1] for _ in range(cell_count)]
    np.random.seed(0)
    visualizations.plot_cells(stas)

    for patch in plt.findobj(plt.gca(), Ellipse):
        assert np.allclose(patch.center, utils.get_default_filter_size()[0] / 2.), \
                'visualizations.plot_cells did not compute correct ellipse center'
        assert np.allclose((patch.height, patch.width), 2.827082246), \
                'visualizations.plot_cells computed incorrect width and/or height'
コード例 #11
0
def test_play_rates():
    """Animate firing rates as a patch colour and check one frame."""
    sta = utils.create_default_fake_filter()[-1]
    rates = utils.create_default_fake_rates()
    _, ax = visualizations.ellipse(sta)
    cell_patch = plt.findobj(ax, Ellipse)[0]
    movie = visualizations.play_rates(rates, cell_patch)

    frame = utils.get_default_movie_frame()
    movie._func(frame)
    gray_levels = plt.cm.gray(np.arange(255))
    # NOTE(review): int() of rates[frame] / rates.max() can only pick index 0
    #   or 1 when rates are non-negative -- confirm this mirrors play_rates.
    expected_rgba = gray_levels[int(rates[frame] / rates.max())]
    assert np.all(cell_patch.get_facecolor()[:3] == expected_rgba[:3]), \
            'visualizations.play_rates did not set patch color correctly'
コード例 #12
0
def test_temporal_filter():
    """Plot a temporal filter directly and check the drawn line."""
    temporal_filter, _, _ = utils.create_default_fake_filter()
    time = np.arange(temporal_filter.size)
    visualizations.temporal(time, temporal_filter)

    trace = plt.findobj(plt.gca(), plt.Line2D)[0]
    assert np.all(trace.get_xdata() == time), \
            'Time axis data is incorrect.'
    assert np.all(trace.get_ydata() == temporal_filter), \
            'Temporal filter data is incorrect.'
    plt.close(plt.gcf())
コード例 #13
0
ファイル: save.py プロジェクト: flaport/mplppt
def savefig(filename, fig=None, axis=True):
    """Export a matplotlib figure to a pptx file.

    Every artist of *fig* that belongs to an axes and has an mplppt
    counterpart (lines, rectangles, polygons, text, pcolormesh) is converted
    and gathered into one group, together with a canvas built from the
    figure's first axes; the group is then written to *filename*.

    Args:
        filename: str: path of the pptx file to write.
        fig: the figure to convert. If None, the most recent figure
            (``gcf()``) is used.
        axis: whether to show the axis ticks and labels (default True).

    Returns:
        group: the mplppt group containing all converted objects.
    """
    fig = gcf() if fig is None else fig

    # matplotlib artist type -> mplppt converter.  Every entry is tested
    # independently (no early exit), so an artist matching several types is
    # converted once per match.
    converters = (
        (mpl.lines.Line2D, Line.from_mpl),
        (mpl.patches.Rectangle, Rectangle.from_mpl),
        (mpl.patches.Polygon, Polygon.from_mpl),
        (mpl.text.Text, Text.from_mpl),
        (mpl.collections.QuadMesh, Mesh.from_mpl),
    )

    group = Group(objects=[])
    for artist in findobj(fig):
        if artist.axes is None:
            continue  # only artists attached to an axes are exported
        for mpl_type, convert in converters:
            if isinstance(artist, mpl_type):
                group += convert(artist)

    # TODO: Create the canvas with fewer parameters
    group += Canvas.from_mpl(fig.axes[0], axis=axis)

    group.save(filename)
    return group
コード例 #14
0
def test_play_rates():
    """Test playing firing rates for cells as a movie.

    One frame of the rate animation is rendered to a PNG and compared with a
    stored baseline.  The rendered file and all open figures are cleaned up
    even when the comparison fails, so a failing run leaves no stale image or
    figure behind for later tests.
    """
    nx, ny, nt = 10, 10, 50
    sta = utils.create_spatiotemporal_filter(nx, ny, nt)[-1]
    time = np.linspace(0, 10, 100)
    spikes = np.arange(10)
    binned_spikes = spiketools.binspikes(spikes, time)
    rate = spiketools.estfr(binned_spikes, time)

    # Plot cell
    fig, axes = viz.ellipse(sta)
    patch = plt.findobj(axes, Ellipse)[0]
    anim = viz.play_rates(rate, patch)
    filename = os.path.join(IMG_DIR, 'test-rates-movie.png')
    frame = 10
    try:
        anim._func(frame)
        plt.savefig(filename)
        # compare_images returns None on a match, otherwise a text description
        # of the difference -- surface that text as the failure message.
        diff = compare_images(
            os.path.join(IMG_DIR, 'baseline-rates-movie-frame.png'), filename, 1)
        assert not diff, diff
    finally:
        if os.path.exists(filename):
            os.remove(filename)
        plt.close('all')
コード例 #15
0
def get_plotting_area(fig):
    """Get the area which is visualized by matplotlib.

    The area is the union of the display-space extents of all spines in
    *fig*.  Spine positions are only known after the figure has been drawn
    once, so the first call for a given figure renders it to a temporary PNG;
    the module-level ``visualized`` dict records figures already rendered.

    Args:
        fig: matplotlib figure to find the area for

    Returns:
        xmin, xmax, ymin, ymax: the bounds of the matplotlib figure

    Raises:
        ValueError: if *fig* contains no spines.
    """
    if fig not in visualized:
        # HACK: To get info about spine locations, the figure needs to be
        # drawn first; we choose png export.  Render *fig* itself -- the
        # pyplot "current" figure may be a different one -- and always remove
        # the temporary file, even if saving fails.
        fn = random_name() + ".png"
        try:
            fig.savefig(fn)
        finally:
            if os.path.exists(fn):
                os.remove(fn)
        visualized[fig] = True
    spines = plt.findobj(fig, mpl.spines.Spine)
    if not spines:
        raise ValueError("figure has no spines; cannot determine plotting area")
    # Each extent converts to a 2x2 array [[x0, y0], [x1, y1]]; stack all
    # corners so the bounds are plain column-wise min/max.
    corners = np.concatenate([np.asarray(spine.get_extents()) for spine in spines])
    xmin = corners[:, 0].min()
    xmax = corners[:, 0].max()
    ymin = corners[:, 1].min()
    ymax = corners[:, 1].max()
    return xmin, xmax, ymin, ymax
コード例 #16
0
def figureToInverseVideo(fig=None, debug=False):
    ##########################################################################
    ##########################################################################
    """Convert a figure, in place, to inverse video (light on black).

    Black line, text, edge and face colours ('k', 'black' or pure-black
    RGB(A)) are switched to white with alpha preserved.  Non-black,
    non-white grays are inverted (g -> 1 - g).  Every axes gets a black
    background and a white foreground, and the figure face colour is set to
    black.  Tip: also pass ``facecolor='k'`` to ``savefig``.

    Known limitations (from the development history):
      * backgrounds of boxed text (``bbox``) are not converted;
      * things drawn in the old background colour (e.g. to cover something)
        are not switched to the new background colour;
      * hex-string grays are not handled;
      * multi-colour arrays (e.g. a scatter with per-point colours) are left
        unchanged.

    Args:
        fig: figure to convert; defaults to ``plt.gcf()``.
        debug: if True, redraw and wait for Enter after each step and print
            every colour decision, to check what is happening.

    Returns:
        An empty tuple (kept for backward compatibility).
    """
    if fig is None:
        fig = plt.gcf()
    # Lookup for blacks to whites; the alpha component is kept.
    k2w = {'k': 'w', 'black': 'white',
           (0, 0, 0): (1, 1, 1), (0, 0, 0, 0): (1, 1, 1, 0), (0, 0, 0, 1): (1, 1, 1, 1)}

    def cpause(ss='Press enter'):
        """In debug mode, redraw and wait for Enter so each step can be inspected."""
        if not debug:
            return
        plt.show()
        plt.draw()
        input(ss)  # ``raw_input`` only exists in Python 2
    cpause('About to invert colours')

    def check_and_set_color_using_function(oo, set_function, colour):
        """Invert *colour* if it is gray or black, applying it via *set_function*.

        *set_function* is the artist's setter matching the getter that
        produced *colour* (e.g. ``set_color`` or ``set_facecolor``).
        """
        origc = colour
        try:
            if len(colour) == 0:
                return
        except TypeError:
            # Previously a deliberate crash via an undefined name; make it explicit.
            raise TypeError('Unsupported colour value {!r} on {}'.format(
                colour, oo.__class__.__name__)) from None

        # Unwrap single-colour arrays such as array([[r, g, b, a]]).
        while hasattr(colour, 'shape') and len(colour) == 1:
            colour = colour[0]
        if getattr(colour, 'ndim', 1) > 1:
            # One colour per element: the per-channel logic below would crash
            # on it (assert / ambiguous truth value), so leave it unchanged.
            if debug: print('     Not changing multi-colour array of {}'.format(oo.__class__.__name__))
            return

        if hasattr(colour, 'shape') or isinstance(colour, tuple):  # numpy array or tuple
            assert len(colour) in [3, 4]
            colour = list(colour)

        # Invert the gray level for gray colours (pure black is mapped below,
        # pure white is left as it is).
        if (not isinstance(colour, str) and len(colour) in [3, 4]
                and colour[0] == colour[1] == colour[2]
                and colour[0] not in [0, 1, 0.0, 1.0]):
            assert colour[0] <= 1
            colour[0] = 1 - colour[0]
            colour[1] = colour[0]
            colour[2] = colour[0]
            set_function(colour)
            if debug: print(' Found gray: {} --> {} in {}'.format(origc, colour, oo.__class__.__name__))
            return

        # TODO: hex-string grays are not handled yet.
        # Strings are looked up as-is; tuple('k') would be ('k',) and never match.
        key = colour if isinstance(colour, str) else tuple(colour)
        if key in k2w:  # "black" to "white"
            if debug: print(' Found "black": {} --> {} in {}'.format(origc, k2w[key], oo.__class__.__name__))
            set_function(k2w[key])
            return
        if debug: print('     Not changing color {} of {}'.format(colour, oo.__class__.__name__))

    def check_and_set_color(oo):
        """Apply the colour check to every colour getter/setter pair *oo* has."""
        if hasattr(oo, 'get_color'):
            check_and_set_color_using_function(oo, oo.set_color, oo.get_color())
        if hasattr(oo, 'get_edgecolor'):
            check_and_set_color_using_function(oo, oo.set_edgecolor, oo.get_edgecolor())
        if hasattr(oo, 'get_facecolor'):
            check_and_set_color_using_function(oo, oo.set_facecolor, oo.get_facecolor())

    for o in fig.findobj(mpl.lines.Line2D):
        check_and_set_color(o)
    cpause('lines')
    for o in fig.findobj(mpl.text.Text):
        check_and_set_color(o)
    cpause('text')

    # Catch-all for everything with a face colour (patches, collections, ...).
    # Axes backgrounds are set afterwards, so this does not conflict with them.
    for o in fig.findobj(lambda ooo: hasattr(ooo, 'set_facecolor')):
        check_and_set_color(o)
    cpause('All faces')

    for aaa in plt.findobj(fig, mpl.axes.Axes):
        set_axis_backgroundcolor(aaa, 'black')  # Not None?
        cpause('Axes bgs')
        set_axis_foregroundcolor(aaa, 'white')
        cpause('Axes fgs')

    # The figure itself (see also the facecolor option of savefig).
    fig.set_facecolor('k')

    return ()
コード例 #17
0
fig, ax = plt.subplots(1, 1, figsize=(24, 14))

display = plotting.plot_matrix(corr,
                               reorder=True,
                               labels=new_labels,
                               cmap='RdBu_r',
                               auto_fit=False,
                               axes=ax)

# Remove the x tick labels and ticks entirely.
display.axes.set_xticklabels([])
display.axes.set_xticks([])

# Horizontal, 10 pt tick labels on the y axis of the current axes.
plt.yticks(rotation=0, fontsize=10)

# Shrink every non-empty, non-numeric text (e.g. region labels) to 8 pt and
# leave numeric labels at their size.  Matplotlib writes negative numbers
# with U+2212 MINUS SIGN by default (rcParams['axes.unicode_minus']), so that
# character is stripped together with '-' and '.' before isnumeric().
NUMBER_PUNCTUATION = str.maketrans('', '', '-.\u2212')
for tt in plt.findobj(fig, mpl.text.Text):
    text = tt.get_text()
    if not text:
        continue
    if text.translate(NUMBER_PUNCTUATION).isnumeric():
        continue
    tt.set_fontsize(8)

# Hide the frame around the matrix.
for spine in ax.spines.values():
    spine.set_visible(False)

fig.savefig('figure_S1.pdf', bbox_inches="tight", dpi=300)
fig.savefig('figure_S1.png', bbox_inches="tight", dpi=300)
コード例 #18
0
ファイル: store.py プロジェクト: biophyscode/omnicalc
def picturesave(savename,
                directory='./',
                meta=None,
                extras=None,
                backup=False,
                dpi=300,
                form='png',
                version=False,
                pdf=False,
                tight=True,
                pad_inches=0,
                figure_held=None,
                loud=True,
                redacted=False):
    """Save the global matplotlib figure without overwriting.

    Note that tuples in *meta* are converted to lists when the metadata is
    written, so if your plotter keeps creating new versions instead of
    overwriting, this is probably why.

    Args:
        savename: base file name (without extension) inside *directory*.
        directory: output directory.
        meta: dict of metadata; written as JSON into PNG files and passed to
            ``picturefind`` when *version* is set.
        extras: artists passed to ``savefig(bbox_extra_artists=...)``;
            ``None`` means an empty list.
        backup: if the target exists, rename it to ``<name>.bakNN.<form>``.
        dpi, tight, pad_inches: forwarded to ``savefig``.
        form: output format and file extension.
        version: append ``.vN``: reuse the name found by ``picturefind`` for
            *meta*, otherwise take the next free index (requires *meta*).
        pdf: render through the PDF backend, then convert to the PNG target
            with ImageMagick ``convert``; only valid with ``form='png'``.
        figure_held: figure to save instead of the pyplot current figure.
        loud: report progress via ``status``.
        redacted: scramble non-numeric text and save into ``REDACTED/``.

    Returns:
        The path of the written file.

    Raises:
        Exception: for an invalid *meta* tuple, *pdf* with a non-png *form*,
            versioning without *meta*, too many backups, or a failing
            ``redacted_scrambler``.
    """
    import subprocess
    #! amazing bug: if you keep a comma after meta it makes it a tuple and then there must be a
    #!   one-way conversion to dict when it is written to the metadata of the image and this causes
    #!   the figure counts to keep increasing no matter what. a very subtle error! corrected below
    if isinstance(meta, tuple):
        if len(meta) != 1 or not isinstance(meta[0], dict):
            raise Exception('meta must be a dict')
        meta = meta[0]
    #---validate before any side effect (e.g. renaming a backup) happens
    if pdf and form != 'png':
        raise Exception('can only use PDF conversion when writing png')
    #---None replaces the old mutable default; an empty list keeps savefig behaviour unchanged
    if extras is None:
        extras = []
    #---automatically share images with group members (note that you could move this to config)
    #---...os.umask changes the permission mask of the whole process
    os.umask(0o002)
    #---earlier import allows users to set Agg so we import here, later
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    #---intervene here to check the workspace for picture-saving "hooks" that apply to all new pictures
    #---! is it necessary to pass the workspace here?
    if 'work' in globals() and 'picture_hooks' in work.metadata.variables:
        extra_meta = work.metadata.variables['picture_hooks']
        #---redundant keys are not allowed: either they are in picture_hooks or passed to picturesave
        redundant_extras = [i for i in extra_meta if i in (meta or {})]
        if redundant_extras:
            raise Exception(
                'keys "%r" are incoming via meta but are already part of picture_hooks'
                % redundant_extras)
        #---NOTE(review): extra_meta is only validated here and never merged into meta,
        #---...so the hooks do not reach the saved metadata; confirm this is intended
    #---redacted figures have blurred labels
    if redacted:
        directory_redacted = os.path.join(directory, 'REDACTED')
        if not os.path.isdir(directory_redacted):
            os.mkdir(directory_redacted)
        directory = directory_redacted
        status('you have requested redacted figures, so they are saved to %s' %
               directory,
               tag='warning')
        #---kept in the local namespace so an eval'd `redacted_scrambler` expression may use it
        import random
        color_back = work.metadata.director.get('redacted_background_color',
                                                '')
        color_fore = work.metadata.director.get('redacted_foreground_color',
                                                'k')
        if 'redacted_scrambler' in work.metadata.director:
            scrambler_code = work.metadata.director['redacted_scrambler']
            #---SECURITY: eval runs arbitrary code from the director configuration;
            #---...acceptable only because that configuration is trusted local input
            try:
                scrambler = eval(scrambler_code)
                scrambler('test text')
            except Exception as err:
                raise Exception(
                    'failed to evaluate your `redacted_scrambler` from the director: `%s`'
                    % scrambler_code) from err
        else:
            #---replace text with hashes (random letters looked silly)
            scrambler = lambda x, max_len=10: ('#' * len(x))[:max_len]
        #---plain numbers (e.g. tick labels) stay readable; everything else is scrambled.
        #---accepts an optional ASCII or Unicode minus and short values such as '0', '5'
        #---...or '0.5' (the previous pattern required at least two digits)
        num_format = re.compile(r'^[-\u2212]?[0-9]*\.?[0-9]+$')
        #---exact type check: Text subclasses (e.g. annotations) are left untouched
        for obj in [i for i in plt.findobj() if type(i) is mpl.text.Text]:
            text_this = obj.get_text()
            if text_this != '' and not num_format.match(text_this):
                obj.set_text(scrambler(text_this))
                if color_back:
                    obj.set_backgroundcolor(color_back)
                obj.set_color(color_fore)
    #---if version then we choose savename based on the next available index
    if version:
        #---check for this meta
        search = picturefind(savename,
                             directory=directory,
                             meta=meta,
                             loud=loud)
        if not search:
            if meta is None:
                raise Exception('[ERROR] versioned image saving requires meta')
            fns = glob.glob(os.path.join(directory, savename + '.v*'))
            #---count versions of the requested format (previously hard-coded to png)
            version_re = re.compile(r'^.+\.v([0-9]+)\.' + re.escape(form))
            nums = [int(m.group(1)) for m in map(version_re.match, fns) if m]
            ind = max(nums) + 1 if nums else 1
            savename += '.v%d' % ind
        else:
            savename = re.findall(r'(.+)\.[a-z]+', os.path.basename(search))[0]
    #---backup if necessary
    savename += '.' + form
    base_fn = os.path.join(directory, savename)
    if loud: status('saving picture to %s' % savename, tag='store')
    if os.path.isfile(base_fn) and backup:
        stem, ext = os.path.splitext(base_fn)
        for i in range(1, 100):
            latestfile = '%s.bak%02d%s' % (stem, i, ext)
            if not os.path.isfile(latestfile): break
        if i == 99 and os.path.isfile(latestfile):
            raise Exception('except: too many copies')
        if loud:
            status('backing up ' + base_fn + ' to ' + latestfile,
                   tag='store')
        os.rename(base_fn, latestfile)
    #---holding the figure allows other programs e.g. ipython notebooks to show and save the figure
    target = figure_held if figure_held else plt
    save_kwargs = dict(
        dpi=dpi,
        bbox_extra_artists=extras,
        bbox_inches='tight' if tight else None,
        pad_inches=pad_inches if pad_inches else None)
    #---intervene to use the PDF backend if desired
    #---...this is particularly useful for the hatch-width hack
    #---...(search self.output(0.1, Op.setlinewidth) in
    #---...python2.7/site-packages/matplotlib/backends/backend_pdf.py and raise it to e.g. 3.0)
    if pdf:
        alt_name = re.sub(r'\.png$', '.pdf', savename)
        #---write real PDF data: format=form ('png') would put PNG bytes in the .pdf
        target.savefig(alt_name, format='pdf', **save_kwargs)
        #---convert pdf to png; an argument list needs no shell and survives spaces in paths
        try:
            subprocess.run(['convert', '-density', '%d' % dpi, alt_name, base_fn],
                           check=True)
        finally:
            os.remove(alt_name)
    else:
        target.savefig(base_fn, format=form, **save_kwargs)
    plt.close()
    #---add metadata to png
    if form == 'png' and meta is not None:
        imgmeta = PngImagePlugin.PngInfo()
        imgmeta.add_text('meta', json.dumps(meta))
        with Image.open(base_fn) as im:
            im.load()  # read all pixels before overwriting the file they came from
            im.save(base_fn, form, pnginfo=imgmeta)
    elif form != 'png':
        print(
            '[WARNING] you are saving as %s and only png allows metadata-versioned pictures'
            % form)
    return base_fn
コード例 #19
0
def figureToGrayscale(fig=None, debug=True):
    ##########################################################################
    ##########################################################################
    """Start converting a figure, in place, towards a grayscale rendering.

    Work in progress (rewritten in May 2012 from figureToInverseVideo).  For
    now it only:
      * clears every axes background and sets its foreground colour to black;
      * sets line and text colours given as strings other than
        'w'/'white'/'k'/'black' to black ('k').
    Numeric (RGB/RGBA) colours, patches and collections are not converted yet.

    Args:
        fig: figure to convert; defaults to ``plt.gcf()``.
        debug: if True (the default), redraw and wait for Enter after each
            step.

    Returns:
        An empty tuple (kept for backward compatibility).
    """
    if fig is None:
        fig = plt.gcf()

    def cpause(ss='Press enter'):
        """In debug mode, redraw and wait for Enter so each step can be inspected."""
        if not debug:
            return
        plt.show()
        plt.draw()
        input(ss)  # ``raw_input`` only exists in Python 2

    for aaa in plt.findobj(fig, mpl.axes.Axes):
        set_axis_backgroundcolor(aaa, None)  # Not None?
        set_axis_foregroundcolor(aaa, 'black')
        cpause('anax')

    def meanColour(oo):
        """Set *oo*'s colour to black if it is a non-black/white colour string."""
        colour = oo.get_color()
        if hasattr(colour, 'shape'):  # numpy array -> plain list
            assert len(colour) in [3, 4]
            colour = list(colour)
        if colour not in ['w', 'white', 'k', 'black', (0, 0, 0), (0, 0, 0, 0), (1, 1, 1), (1, 1, 1, 1)]:
            if isinstance(colour, str):
                # Act on the artist passed in, not on an enclosing loop variable.
                oo.set_color('k')
            # TODO: numeric (RGB/RGBA) colours are not converted yet.

    for o in fig.findobj(mpl.lines.Line2D):
        meanColour(o)
    cpause('lines')
    for o in fig.findobj(mpl.text.Text):
        meanColour(o)
    cpause('text')

    # TODO: convert face/edge colours of mpl.patches.Patch and
    #   mpl.collections.PolyCollection objects as well.
    return ()