Example #1
0
 def get_color_map(self, levels):
     """Return a green-to-red gradient of hex colour strings, one per level.

     Levels are divided by their maximum and mapped through the reversed
     'RdYlGn' colormap (low values green, high values red).  The
     ScalarMappable has no explicit norm, so it autoscales to the min/max of
     the scaled levels.

     :param levels: 1-D sequence or array of numeric levels
     :return: list of ``'#rrggbb'`` strings, same length as *levels*
     """
     levels = np.asarray(levels, dtype=float)
     if levels.size == 0:
         return []
     peak = np.max(levels)
     # An all-zero input would otherwise divide by zero and yield NaN
     # ("bad", transparent black) colours.
     normed_levels = levels / peak if peak != 0 else levels
     sm = ScalarMappable(cmap='RdYlGn_r')
     # '%02x' requires ints on Python 3; truncate the 0..255 floats.
     channels = (255 * sm.to_rgba(normed_levels)[:, :3]).astype(int)
     return ['#%02x%02x%02x' % (r, g, b) for r, g, b in channels]
Example #2
0
class ColorbarWidget(QWidget):
    """Qt widget that shows a standalone colorbar and an "Update" button.

    The colorbar is first drawn with a symmetric-log norm (linthresh 1e-4,
    range -10..10000).  Pressing "Update" replaces it with a linear norm over
    30..4300.
    """

    # Colorbar axes position in figure coordinates: (left, bottom, width, height).
    _CB_RECT = (0.25, 0.05, 0.1, 0.90)

    def __init__(self, parent=None):
        super(ColorbarWidget, self).__init__(parent)
        fig = Figure()
        self.cb_axes = fig.add_axes(self._CB_RECT)
        self.canvas = FigureCanvas(fig)
        self.setLayout(QVBoxLayout())
        self.layout().addWidget(self.canvas)
        self.button = QPushButton("Update")
        self.layout().addWidget(self.button)
        self.button.pressed.connect(self._update_cb_scale)

        self._create_colorbar(fig)

    def _create_colorbar(self, fig):
        """Build the initial symmetric-log colorbar in ``self.cb_axes``."""
        self.mappable = ScalarMappable(norm=SymLogNorm(0.0001, 1, vmin=-10., vmax=10000.),
                                       cmap=DEFAULT_CMAP)
        self._attach_colorbar(fig)

    def _update_cb_scale(self):
        """Replace the colorbar with a linear one over 30..4300 and redraw."""
        fig = self.canvas.figure
        self.mappable.colorbar.remove()
        # Fresh colorbar axes at the same position as the original ones.
        self.cb_axes = fig.add_axes(self._CB_RECT)
        self.mappable = ScalarMappable(Normalize(30, 4300), cmap=DEFAULT_CMAP)
        self._attach_colorbar(fig)
        self.canvas.draw()

    def _attach_colorbar(self, fig):
        """Draw a colorbar for ``self.mappable`` into ``self.cb_axes`` of *fig*."""
        # The mappable carries no data of its own; give it an empty array
        # before handing it to colorbar().
        self.mappable.set_array([])
        fig.colorbar(self.mappable, ax=self.cb_axes, cax=self.cb_axes)
Example #3
0
def make_graph(data):
    """Build a networkx graph from underscore-delimited keys, draw and show it.

    Each key in *data* is split on ``'_'``.  A key with exactly five fields
    adds an edge between fields 1 and 3.  Any other key is added as a
    standalone node.

    Node colours come from the 'jet' colormap normalised to [1, 241].  Nodes
    whose name contains ``'_'`` get the fully transparent colour
    ``(0, 1, 0, 0)``.  Every other node name is passed through ``int()``, so it
    must be a numeric string, otherwise ``ValueError`` is raised.
    """
    from matplotlib.cm import jet, ScalarMappable
    from matplotlib.colors import Normalize

    g = nx.Graph()
    smap = ScalarMappable(norm=Normalize(vmin=1, vmax=241), cmap=jet)
    for key in data:
        fields = key.split('_')
        if len(fields) != 5:
            g.add_node(key)
        else:
            a, b = fields[1], fields[3]
            g.add_node(a)
            g.add_node(b)
            g.add_edge(a, b)

    pos = nx.spring_layout(g)
    # Colours are listed in g.nodes() order, which is also the order
    # draw_networkx_nodes uses when no nodelist is given.
    nxcols = [smap.to_rgba(int(node)) if '_' not in node else (0, 1, 0, 0)
              for node in g.nodes()]
    nx.draw_networkx_nodes(g, pos, node_color=nxcols)
    nx.draw_networkx_labels(g, pos)
    nx.draw_networkx_edges(g, pos)
    plt.show()
Example #4
0
def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'):
    '''
    Turn an intensity array into an RGB image through a colour map.

    The input is reshaped into a 2-D image with ``vec2im``, enlarged by
    ``enlarge_image`` and coloured with an autoscaled ScalarMappable.  The
    alpha channel is dropped.

    Parameters
    ----------

    X : numpy.ndarray
        intensity matrix as array of shape [M x N]

    scaling : int
        optional. positive integer (> 0) enlargement factor

    shape : tuple or list of ints, length = 2
        optional. if not given, X is reshaped to be square.

    cmap : str
        name of the colour map to use. default is 'binary'

    Returns
    -------

    image : numpy.ndarray
        three-dimensional array of shape [scaling*H x scaling*W x 3], where H*W == M*N
    '''
    mapper = ScalarMappable(cmap=cmap)
    as_image = vec2im(X, shape)
    enlarged = enlarge_image(as_image, scaling)
    rgba = mapper.to_rgba(enlarged)
    return rgba[:, :, :3]
Example #5
0
def create_heatmap(xs, ys, imageSize, blobSize, cmap):
    """Render a heat map of 2-D points as a coloured PIL RGBA image.

    Every point is stamped as a Gaussian-blurred grey disc, and the stamps
    are summed with saturation (``ImageChops.add``).  The red channel of the
    sum is normalised to its own min/max and mapped through *cmap*.  The
    summed alpha channel is kept as the output alpha.

    :param xs: x coordinates of the points
    :param ys: y coordinates of the points (same length as *xs*)
    :param imageSize: side length, in pixels, of the square output image
    :param blobSize: blob size in pixels (each blob canvas is ``2 * blobSize`` wide)
    :param cmap: matplotlib colormap name or object
    :return: ``PIL.Image.Image`` in ``'RGBA'`` mode
    :raises ValueError: if there are no points or *xs*/*ys* differ in length
    """
    if len(xs) == 0:
        raise ValueError("create_heatmap() needs at least one point")
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")

    # Grey level of one blob, reduced as the point count grows.  Integer
    # division: PIL requires integer fill components.
    colour = 255 // int(math.sqrt(len(xs)))
    blob = Image.new('RGBA', (blobSize * 2, blobSize * 2), '#000000')
    blob.putalpha(0)
    draw = ImageDraw.Draw(blob)
    draw.ellipse((blobSize / 2, blobSize / 2, blobSize * 1.5, blobSize * 1.5),
                 fill=(colour, colour, colour))
    blob = blob.filter(ImageFilter.GaussianBlur(radius=blobSize / 2))

    heat = Image.new('RGBA', (imageSize, imageSize), '#000000')
    heat.putalpha(0)
    # Data -> pixel mapping.  The y scale is negative so larger y values land
    # nearer the top.  A zero spread maps every point to the first row/column
    # instead of dividing by zero.
    x_span = max(xs) - min(xs)
    y_span = min(ys) - max(ys)
    xScale = float(imageSize - 1) / x_span if x_span else 0.0
    yScale = float(imageSize - 1) / y_span if y_span else 0.0
    xOff = min(xs)
    yOff = max(ys)
    for x, y in zip(xs, ys):
        xPos = int((x - xOff) * xScale)
        yPos = int((y - yOff) * yScale)
        blobLoc = Image.new('RGBA', (imageSize, imageSize), '#000000')
        blobLoc.putalpha(0)
        blobLoc.paste(blob, (xPos - blobSize, yPos - blobSize), blob)
        heat = ImageChops.add(heat, blobLoc)

    heatArray = pil_to_array(heat)
    intensity = heatArray[:, :, 0]
    # Normalise over the channel that is actually colour-mapped (red).
    # Taking extrema of whole RGBA tuples would let the alpha value set the range.
    norm = Normalize(vmin=int(intensity.min()), vmax=int(intensity.max()))
    sm = ScalarMappable(norm, cmap)
    rgba = sm.to_rgba(intensity, bytes=True)
    rgba[:, :, 3] = heatArray[:, :, 3]
    return Image.fromarray(rgba, 'RGBA')
def scatter3d(x, y, z, cs, colorsMap='jet', ax=None):
    """Scatter-plot 3-D points coloured by the values *cs*, then show the figure.

    :param x, y, z: point coordinates
    :param cs: per-point values, mapped linearly (min..max) through *colorsMap*
    :param colorsMap: colormap name
    :param ax: optional existing 3-D axes to draw on.  When omitted, a new
        figure with a ``projection='3d'`` axes is created.
    """
    cm = plt.get_cmap(colorsMap)
    cNorm = Normalize(vmin=min(cs), vmax=max(cs))
    scalarMap = ScalarMappable(norm=cNorm, cmap=cm)
    if ax is None:
        # Without this, ``ax`` would be an undefined name here.
        ax = plt.figure().add_subplot(111, projection='3d')
    ax.scatter(x, y, z, c=scalarMap.to_rgba(cs), s=5, linewidth=0)
    # Attach the data so the mappable could back a colorbar.
    scalarMap.set_array(cs)
    plt.show()
Example #7
0
 def getcolor(self, V, F):
     """Colour each face by the squared norm of its first vertex.

     Returns ``(mappable, rgba)``.  *mappable* is a 'jet' ScalarMappable whose
     array is set to the per-face values.  *rgba* is the matching colour array.
     """
     sq_norm = numpy.empty(len(F))
     for idx, face in enumerate(F):
         p = V[face][0]
         sq_norm[idx] = p[0] * p[0] + p[1] * p[1] + p[2] * p[2]
     mappable = ScalarMappable(cmap='jet')
     mappable.set_array(sq_norm)
     rgba = mappable.to_rgba(sq_norm)
     return mappable, rgba
Example #8
0
def _check_cmap_rgb_vals(vals, cmap, vmin=0, vmax=1):
    """Assert that *cmap* maps each value to its expected RGB triple.

    *vals* is an iterable of ``(value, expected_rgb)`` pairs.  Values are
    normalised linearly over ``[vmin, vmax]``, the last (alpha) component of
    each colour is ignored, and colours are compared with ``assert_allclose``
    at ``atol=1e-5``.
    """
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    mapper = ScalarMappable(norm=Normalize(vmin, vmax), cmap=cmap)
    for value, expected in vals:
        actual = mapper.to_rgba(value)[:-1]
        assert_allclose(actual, expected, atol=1e-5)
def plot_events(mapobj, axisobj, catalog, label=None, color='depth', pretty=False, colormap=None,
                llat=-90, ulat=90, llon=-180, ulon=180, figsize=(16, 24),
                par_range=(-90., 120., 30.), mer_range=(0., 360., 60.),
                showHour=False, M_above=0.0, location='World', min_size=1, max_size=8, **kwargs):
    '''Simplified version of plot_event: scatter catalog events on *mapobj*.

    Events are selected by ``get_event_info`` (magnitude threshold *M_above*
    and the llat/ulat/llon/ulon box).  Marker colour encodes the *color*
    attribute ('depth' by default, or 'date').  Marker area grows with
    magnitude on a fixed 0..10 scale.  When ``mags`` has at most one entry,
    a single red marker of area ``15.0 ** 2`` is used instead.

    ``axisobj``, ``pretty``, ``figsize``, ``par_range``, ``mer_range``,
    ``showHour``, ``location`` and ``**kwargs`` are accepted but not used.

    Returns whatever ``mapobj.scatter`` returns.
    '''
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    lats, lons, mags, times, labels, colors = get_event_info(
        catalog, M_above, llat, ulat, llon, ulon, color, label)
    lo, hi = min(colors), max(colors)

    if colormap is None:
        # Default colormap for dates; green->yellow->red for the depth encoding.
        colormap = plt.get_cmap() if color == "date" else plt.get_cmap("RdYlGn_r")

    scal_map = ScalarMappable(norm=Normalize(lo, hi), cmap=colormap)
    scal_map.set_array(np.linspace(0, 1, 1))

    x, y = mapobj(lons, lats)

    if len(mags) > 1:
        min_mag, max_mag = 0, 10
        size_span = max_size - min_size
        # Area = (magnitude as a fraction of min_mag..max_mag, times the span) squared.
        magnitude_size = [((m - min_mag) / (max_mag - min_mag) * size_span) ** 2
                          for m in mags]
        colors_plot = [scal_map.to_rgba(c) for c in colors]
    else:
        magnitude_size = 15.0 ** 2
        colors_plot = "red"

    return mapobj.scatter(x, y, marker='o', s=magnitude_size, c=colors_plot, zorder=10)
 def _create_colorbar(self, cmap, ncolors, labels, **kwargs):
     """Create a discretized colorbar with one tick label per entry of *labels*.

     :param cmap: colormap object (its ``N`` is passed to the BoundaryNorm)
     :param ncolors: number of discrete colours
     :param labels: tick labels, normally ``ncolors`` of them
     :param kwargs: forwarded to ``plt.colorbar``
     :return: the created colorbar
     """
     norm = BoundaryNorm(range(0, ncolors), cmap.N)
     mappable = ScalarMappable(cmap=cmap, norm=norm)
     mappable.set_array([])
     mappable.set_clim(-0.5, ncolors+0.5)
     colorbar = plt.colorbar(mappable, **kwargs)
     labels = list(labels)
     # One tick per label, at i + 0.5.  Tick and label counts must agree:
     # matplotlib >= 3.5 raises ValueError on a mismatch.
     colorbar.set_ticks(np.arange(len(labels)) + 0.5)
     colorbar.set_ticklabels(labels)
     return colorbar
Example #11
0
def cm2colors(N, cmap='autumn'):
    """Take N evenly spaced colours out of *cmap*.

    :param N: number of colours to sample
    :param cmap: colormap name or object
    :return: list of N RGBA tuples, from the low end (value 0) to the high end
        (value N-1) of the colormap.  An empty list is returned when N is 0.
    """
    scalarMap = ScalarMappable(norm=Normalize(vmin=0, vmax=N - 1), cmap=cmap)
    # range (not the Python-2-only xrange) keeps this working on Python 3.
    return [scalarMap.to_rgba(i) for i in range(N)]
Example #12
0
    def plotTciData(self, ax):
        """Plot the ten TCI ``nl_01``..``nl_10`` signals against their time base on *ax*.

        Signal ``i`` is coloured by entry ``i-1`` of the negated TCI ``rad``
        array, mapped through a default (autoscaled) ScalarMappable.
        """
        # Raw strings: '\e' is an invalid escape sequence (SyntaxWarning on
        # Python 3.12).  The literal value is unchanged.
        rdata = self.elecTree.getNode(r'\electrons::top.tci.results:rad').data()
        rhoColor = ScalarMappable().to_rgba(-rdata)

        for i in range(1, 11):
            pnode = self.elecTree.getNode(r'\electrons::top.tci.results:nl_%02d' % i)
            ax.plot(pnode.dim_of().data(), pnode.data(), c=rhoColor[i - 1])

        ax.set_ylabel('ne')
Example #13
0
def plot_net_layerwise(net, x_spacing=5, y_spacing=10, colors=None, use_labels=True, ax=None, cmap='gist_heat', cbar=False, positions=None):
    """Draw *net* with its nodes arranged in horizontal layers by BFS depth.

    :param net: network providing ``size``, ``nodes``, ``iter_nodes``,
        ``get_roots`` and ``bfs_traverse``; drawn via ``net_to_digraph``
    :param x_spacing: horizontal gap between nodes of a layer
    :param y_spacing: vertical gap between layers
    :param colors: per-node colour values in ``net.nodes()`` order, mapped
        through *cmap* over [0, 1]; defaults to all ones
    :param use_labels: label each node with ``get_name()``
    :param ax: axes to draw on; a new figure/axes is created when None
    :param cmap: colormap name
    :param cbar: add a colorbar for the [0, 1] colour range
    :param positions: optional precomputed ``{node: (x, y)}``; computed
        layer-wise (layers sorted by name and centred on x = 0) when empty
    """
    if not colors:
        colors = [1] * net.size()
    args = {
        'node_color': colors,
        # ensure that the same order is used throughout for parallel data like colors
        'nodelist': net.nodes(),
        'vmin': 0,
        'vmax': 1,
        'cmap': cmap,
    }

    if not positions:
        # Group nodes by their BFS distance from the roots.
        nodes_by_layer = defaultdict(list)

        def add_to_layer(n, l):
            nodes_by_layer[l].append(n)
        net.bfs_traverse(net.get_roots(), add_to_layer)

        positions = {}
        for layer, nodes in nodes_by_layer.items():
            y = -layer * y_spacing
            # reorder layer lexicographically, then centre it on x = 0
            nodes.sort(key=lambda n: n.get_name())
            width = (len(nodes) - 1) * x_spacing
            for i, n in enumerate(nodes):
                positions[n] = (x_spacing * i - width / 2, y)
    args['pos'] = positions

    if use_labels:
        args['labels'] = {n: n.get_name() for n in net.iter_nodes()}

    if ax is None:
        ax = plt.figure().add_subplot(1, 1, 1)
    # Set only now, so networkx draws on the same axes that is styled below.
    args['ax'] = ax
    nxg = net_to_digraph(net)
    nx.draw_networkx(nxg, **args)
    # Booleans: the strings 'off' are truthy and would leave ticks visible.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

    if cbar:
        color_map = ScalarMappable(cmap=cmap)
        color_map.set_clim(vmin=0, vmax=1)
        color_map.set_array(np.array([0, 1]))
        plt.colorbar(color_map, ax=ax)

    ax.set_aspect('equal')
    # zoom out slightly to avoid cropping issues with nodes
    xl = ax.get_xlim()
    yl = ax.get_ylim()
    ax.set_xlim(xl[0] - x_spacing / 2, xl[1] + x_spacing / 2)
    ax.set_ylim(yl[0] - y_spacing / 2, yl[1] + y_spacing / 2)
Example #14
0
  def _update_scatter(self):
    # Refresh only the face colours of the existing scatter points after a
    # change of *tidx* or *xidx*; point positions are not touched.  Colours
    # come from the third component of the second data set at self.tidx,
    # using the norm and colormap of self.cbar.
    #
    # CALL THIS AFTER *_update_image*
    #
    if len(self.data_sets) >= 2:
      mapper = ScalarMappable(norm=self.cbar.norm, cmap=self.cbar.get_cmap())
      z_values = self.data_sets[1][self.tidx, :, 2]
      self.scatter.set_facecolors(mapper.to_rgba(z_values))
Example #15
0
    def densityplot2(self, modelname='Model', refname='Ref', units='mmol m-3', sub='med'):
        '''
        opectool-like density plot

        Ref is on the x axis, Model on the y axis.  Both axes span the common
        min..max of Ref and Model.  The colorbar shows the hexbin counts; if
        its automatic tick labels contain decimals, it is rebuilt with integer
        ticks.

        Args
         - *modelname* (optional) string , default ='Model'
         - *refname*   (optional) string , default ='Ref'
         - *units*     (optional) string , default ='mmol m-3'
         - *sub*       (optional) string , default ='med'

        Returns: a matplotlib Figure object and a matplotlib Axes object
        '''
        fig, ax = plt.subplots()
        plt.title('%s Density plot of %s and %s\nNumber of considered matchups: %s' % (sub, modelname, refname, self.number()))
        # 'spectral' was renamed 'nipy_spectral'; the old name was removed in matplotlib 2.2.
        cmap = 'nipy_spectral_r'
        axis_min = min(self.Ref.min(), self.Model.min())
        axis_max = max(self.Ref.max(), self.Model.max())
        extent = [axis_min, axis_max, axis_min, axis_max]

        hexbin = ax.hexbin(self.Ref, self.Model, bins=None, extent=extent, cmap=cmap)
        data = hexbin.get_array().astype(np.int32)
        MAX = data.max()

        # Largest tick count in 10..3 whose evenly spaced ticks over [0, MAX]
        # are all whole numbers; if none qualifies, the 3-tick set is kept.
        for nticks in range(10, 2, -1):
            float_array = np.linspace(0, MAX, nticks)
            int_ticks = float_array.astype(np.int32)
            if np.all(float_array == int_ticks):
                break

        mappable = ScalarMappable(cmap=cmap)
        mappable.set_array(data)
        cbar = fig.colorbar(mappable, ax=ax)
        has_decimal_labels = any('.' in str(label.get_text())
                                 for label in cbar.ax.get_yticklabels())
        if has_decimal_labels:
            cbar.remove()
            cbar = fig.colorbar(mappable, ticks=int_ticks, ax=ax)

        ax.set_xlabel('%s %s' % (refname, units))
        ax.set_ylabel('%s %s' % (modelname, units))
        ax.grid()
        return fig, ax
def generalized_bar_chart(code_matrix, trans_names, code_names, show_it=True, show_trans_names=False, color_map="jet", legend_labels=None, title=None, horizontal_grid=True):
    """Draw a grouped bar chart: one group per row, one bar per column of *code_matrix*.

    :param code_matrix: 2-D array indexed ``[row, code]``
    :param trans_names: row names; their count sets the number of groups
    :param code_names: not used -- codes are always numbered 0..n-1 from the
        columns of *code_matrix* (kept for call compatibility)
    :param show_it: call ``fig.show()`` before returning
    :param show_trans_names: label groups with *trans_names* instead of 1..n
    :param color_map: colormap name, or ``"AA_traditional_"`` for a fixed
        10-colour palette that is cycled
    :param legend_labels: optional legend entries, one per code
    :param title: optional axes title
    :param horizontal_grid: show major y grid lines (only when group names are hidden)
    :return: the figure
    """
    fig = pylab.figure(facecolor="white", figsize=(12, 4))
    fig.subplots_adjust(left=.05, bottom=.15, right=.98, top=.95)
    n_rows = len(trans_names)
    code_names = list(range(code_matrix.shape[1]))
    # ldata[code] -> bar heights of that code, one per row
    ldata = {code: [code_matrix[j, code] for j in range(n_rows)] for code in code_names}
    ind = np.arange(n_rows)
    width = 1.0 / (len(code_names) + 1)
    traditional_colors = ['red', 'orange', 'yellow', 'green', 'blue', 'purple', 'black', 'grey', 'cyan', 'coral']

    ax = fig.add_subplot(111)
    if title is not None:
        ax.set_title(title, fontsize=10)

    if color_map == "AA_traditional_":
        lcolors = traditional_colors
    else:
        cNorm = mpl_Normalize(vmin=1, vmax=len(code_names))
        comap = get_cmap(color_map)
        scalar_map = ScalarMappable(norm=cNorm, cmap=comap)
        lcolors = [scalar_map.to_rgba(idx + 1) for idx in range(len(code_names))]

    the_bars = []
    for c, code in enumerate(code_names):
        new_bars = ax.bar(ind + (c + .5) * width, ldata[code], width, color=lcolors[c % len(lcolors)])
        # one representative bar per code, used as the legend handle
        the_bars.append(new_bars[0])
    if show_trans_names:
        ax.set_xticks(ind + .5)
        ax.set_xticklabels(trans_names, size="x-small", rotation=-45)
    else:
        # Positional flag: the ``b=`` keyword was renamed ``visible=`` in
        # matplotlib 3.5 and later removed.
        ax.grid(horizontal_grid, which="major", axis='y')
        ax.set_xticks(ind + .5)
        ax.set_xticklabels(ind + 1, size="x-small")
        for i in ind[1:]:
            ax.axvline(x=i, linestyle="--", linewidth=.25, color='black')

    if legend_labels is not None:
        fontP = FontProperties()
        fontP.set_size('small')
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.825, box.height])
        # Put a legend to the right of the current axis
        ax.legend(the_bars, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5), prop=fontP)
    ax.set_xlim(right=len(trans_names))

    if show_it:
        fig.show()
    return fig
Example #17
0
def plot(countries, values, label='', clim=None, verbose=False):
    """
    Usage: worldmap.plot(countries, values [, label] [, clim])

    Draw a world choropleth: each Natural Earth country whose (normalised)
    ISO-3 code appears in *countries* is filled with the colour of its
    matching entry in *values*.  Countries with a non-finite value are grey,
    and countries without data are white.  The colour range is *clim* when
    given, otherwise mean +/- 2 standard deviations of the finite values.
    A horizontal colorbar labelled *label* is added, and the figure is shown.
    """
    shapefile = shpreader.natural_earth(resolution='110m', category='cultural',
                                        name='admin_0_countries')
    fig = plt.figure()
    ax = plt.axes(projection=ccrs.PlateCarree())

    cmap = plt.get_cmap('RdYlGn_r')
    if clim:
        vmin, vmax = clim[0], clim[1]
    else:
        finite = values[np.isfinite(values)]
        center = finite.mean()
        spread = 2 * finite.std()
        vmin, vmax = center - spread, center + spread
    norm = Normalize(vmin=vmin, vmax=vmax)
    smap = ScalarMappable(norm=norm, cmap=cmap)
    cbar_ax = fig.add_axes([0.3, 0.18, 0.4, 0.03])
    cbar = ColorbarBase(cbar_ax, cmap=cmap, norm=norm, orientation='horizontal')
    cbar.set_label(label)

    # Natural Earth codes that differ from the codes used in *countries*:
    # South Sudan, Romania, Dem. Rep. Congo, Kosovo.
    code_fixes = {'SDS': 'SSD', 'ROU': 'ROM', 'COD': 'ZAR', 'KOS': 'KSV'}
    for country in shpreader.Reader(shapefile).records():
        code = country.attributes['adm0_a3']
        name = country.attributes['name_long']
        code = code_fixes.get(code, code)
        if code not in countries:
            facecolor = 'w'
            if verbose:
                print("No data available for "+code+": "+name)
        else:
            match = values[countries == code]
            facecolor = smap.to_rgba(match) if np.isfinite(match) else 'grey'
        ax.add_geometries(country.geometry, ccrs.PlateCarree(), facecolor=facecolor, label=name)
    plt.show()
def custom_colorbar(cmap, ncolors, breaks, **kwargs):
    """Create a discretized colorbar labelled with count thresholds.

    The labels are ``'No Counts'`` followed by ``'> %d counts'`` for every
    entry of *breaks* except the last, giving ``len(breaks)`` labels.

    NOTE(review): another function named ``custom_colorbar`` is defined later
    in this file and shadows this one once the module is loaded.

    :param cmap: colormap object (its ``N`` is passed to the BoundaryNorm)
    :param ncolors: number of discrete colours
    :param breaks: count thresholds used to build the labels
    :param kwargs: forwarded to ``plt.colorbar``
    :return: the created colorbar
    """
    from matplotlib.colors import BoundaryNorm
    from matplotlib.cm import ScalarMappable

    breaklabels = ['No Counts'] + ["> %d counts" % (perc) for perc in breaks[:-1]]

    norm = BoundaryNorm(range(0, ncolors), cmap.N)
    mappable = ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    mappable.set_clim(-0.5, ncolors + 0.5)
    colorbar = plt.colorbar(mappable, **kwargs)
    # One tick per label, at i + 0.5.  Tick and label counts must agree:
    # matplotlib >= 3.5 raises ValueError on a mismatch.
    colorbar.set_ticks(np.arange(len(breaklabels)) + 0.5)
    colorbar.set_ticklabels(breaklabels)
    return colorbar
def generate_colors(desired_palette, num_desired_colors):
  """
  Generate a list of colour strings interpolated from *desired_palette*.

  *desired_palette* is a palettable palette; its ``mpl_colormap`` is sampled
  at ``0, 1/n, ..., (n-1)/n`` for ``n = num_desired_colors``.  The sampling
  never reaches 1.0, so the last palette colour itself is not included.  Each
  sample is scaled to 0..255 (truncated) and passed to ``rgb_to_hex``.

  Conceptually, this takes a list of colors and lets you generate any number
  of colors from it.
  """
  cmap = desired_palette.mpl_colormap
  mappable = ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap=cmap)

  cols = []
  for i in range(num_desired_colors):
    # float() keeps this a true division on Python 2 as well.
    r, g, b, _ = mappable.to_rgba(i / float(num_desired_colors))
    # A real list (not a lazy map object) so rgb_to_hex can index it.
    cols.append(rgb_to_hex([int(c * 255) for c in (r, g, b)]))

  return cols
Example #20
0
  def _init_scatter(self):
    # Create scatter points at the base of each vector showing the vertical
    # deformation (third component) of the second data set at self.tidx,
    # coloured with the norm and colormap of self.cbar.  With fewer than two
    # data sets, self.scatter is set to None and nothing is drawn.
    #
    # CALL THIS AFTER *_init_image*
    #
    if len(self.data_sets) < 2:
      self.scatter = None
      return

    mapper = ScalarMappable(norm=self.cbar.norm, cmap=self.cbar.get_cmap())
    z_colors = mapper.to_rgba(self.data_sets[1][self.tidx, :, 2])
    self.scatter = self.map_ax.scatter(
        self.x[:, 0], self.x[:, 1],
        c=z_colors,
        s=self.scatter_size,
        zorder=2,
        edgecolor=self.colors[1])
Example #21
0
 def _update_cb_scale(self):
     """Swap the current colorbar for a linear one over 30..4300 and redraw."""
     figure = self.canvas.figure
     self.mappable.colorbar.remove()
     # Recreate the colorbar axes at (left, bottom, width, height) in figure coords.
     self.cb_axes = figure.add_axes((0.25, 0.05, 0.1, 0.90))
     linear = ScalarMappable(Normalize(30, 4300), cmap=DEFAULT_CMAP)
     linear.set_array([])
     self.mappable = linear
     figure.colorbar(linear, ax=self.cb_axes, cax=self.cb_axes)
     self.canvas.draw()
Example #22
0
    def set_data(self, x_data, y_data, axis_min, axis_max, matchup_count, log):
        """Draw a hexbin density plot of *y_data* against *x_data* with a colorbar.

        :param x_data: values for the x axis
        :param y_data: values for the y axis
        :param axis_min: lower limit of both axes' extent
        :param axis_max: upper limit of both axes' extent
        :param matchup_count: forwarded to ``self.update_title``
        :param log: if true, hexbin uses ``bins='log'``
        """
        logging.debug('Creating density plot...')

        # 'spectral' was renamed 'nipy_spectral'; the old name was removed in matplotlib 2.2.
        cmap = 'nipy_spectral_r'
        extent = [axis_min, axis_max, axis_min, axis_max]
        bin_spec = 'log' if log else None
        hexbin = self.ax.hexbin(x_data, y_data, bins=bin_spec, extent=extent, cmap=cmap)

        # Separate mappable holding the binned values; its norm autoscales to them.
        mappable = ScalarMappable(cmap=cmap)
        mappable.set_array(hexbin.get_array())
        self.fig.colorbar(mappable, ax=self.ax)

        logging.debug('...success!')
        self.update_title(matchup_count)
def custom_colorbar(cmap, ncolors, labels, **kwargs):
    """Create a custom, discretized colorbar with correctly formatted/aligned labels.

    cmap: the matplotlib colormap object you plan on using for your graph
    ncolors: (int) the number of discrete colors available
    labels: the list of labels for the colorbar. Should be the same length as ncolors.
    kwargs: forwarded to ``plt.colorbar``

    Returns the created colorbar.
    """
    from matplotlib.colors import BoundaryNorm
    from matplotlib.cm import ScalarMappable

    norm = BoundaryNorm(range(0, ncolors), cmap.N)
    mappable = ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    mappable.set_clim(-0.5, ncolors + 0.5)
    colorbar = plt.colorbar(mappable, **kwargs)
    labels = list(labels)
    # One tick per label, at i + 0.5.  Tick and label counts must agree:
    # matplotlib >= 3.5 raises ValueError on a mismatch.
    colorbar.set_ticks(np.arange(len(labels)) + 0.5)
    colorbar.set_ticklabels(labels)
    return colorbar
    def __plot_variance(self):
        """Shade the band between the minimum and maximum curves.

        For every x key of ``__calc_min()``, a closed polygon is built from the
        previous x to this one (the first x gives a degenerate polygon) using
        the min and max values at both ends.  Each polygon is coloured by the
        ``max - min`` spread at its x, mapped through ``self.colourMap`` and
        normalised over all spreads.  The polygons are added to ``self.axes``
        as one ``PolyCollection`` with gid ``'plot'``.  Nothing is drawn when
        there are no points.

        NOTE(review): polygons join consecutive keys in dict iteration order;
        this assumes ``__calc_min`` returns its keys ordered by x -- confirm.

        :return: ``(None, None)``
        """
        pointsMin = self.__calc_min()
        pointsMax = self.__calc_max()

        polys = []
        variance = []
        lastX = None
        lastYMin = None
        lastYMax = None
        for x in pointsMin:
            if lastX is None:
                lastX = x
            if lastYMin is None:
                lastYMin = pointsMin[x]
            if lastYMax is None:
                lastYMax = pointsMax[x]
            polys.append([[x, pointsMin[x]],
                          [x, pointsMax[x]],
                          [lastX, lastYMax],
                          [lastX, lastYMin],
                          [x, pointsMin[x]]])
            lastX = x
            lastYMin = pointsMin[x]
            lastYMax = pointsMax[x]

            variance.append(pointsMax[x] - pointsMin[x])

        if not variance:
            # Nothing to draw and no spread range to normalise over.
            return None, None

        # Normalise over the true extremes.  Fixed seed values would distort
        # the range when spreads exceed them.
        norm = Normalize(vmin=min(variance), vmax=max(variance))
        sm = ScalarMappable(norm, self.colourMap)
        colours = sm.to_rgba(variance)

        pc = PolyCollection(polys)
        pc.set_gid('plot')
        pc.set_norm(norm)
        pc.set_color(colours)
        self.axes.add_collection(pc)

        return None, None
Example #25
0
def pyplot_bar(y, cmap='Blues'):
    """Make a good-looking pyplot bar plot, colouring each bar by its height.

    The colour scale runs from ``(3*min(y) - max(y)) / 2`` up to ``max(y)``.
    When the heights differ, the shortest bar therefore lands one third of
    the way along the colormap instead of at its (often blank) low end.

    y: heights of the bars
    cmap: colormap used to colour the bars, defaults to 'Blues'
    """
    import matplotlib.pyplot as plt

    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    vmax = numpy.max(y)
    vmin = (numpy.min(y) * 3. - vmax) / 2.

    # Use the caller's colormap (this used to be hard-coded to 'Blues').
    colormap = ScalarMappable(norm=Normalize(vmin, vmax), cmap=cmap)

    plt.bar(numpy.arange(len(y)), y, color=colormap.to_rgba(y), align='edge', width=1.0)
Example #26
0
    def plotThomsonEdgeData(self, ax):
        """Plot edge Thomson ``NE`` channels versus time on *ax*.

        Time slices where the first ``RMID`` entry is not positive are
        dropped.  Each channel is drawn in the mean colour, over the kept
        times, of its negated ``RMID`` values mapped through a default
        (autoscaled) ScalarMappable.
        """
        # Raw strings: '\E' is an invalid escape sequence (SyntaxWarning on
        # Python 3.12).  The literal values are unchanged.
        proNode = self.elecTree.getNode(r'\ELECTRONS::TOP.YAG_EDGETS.RESULTS:NE')
        rhoNode = self.elecTree.getNode(r'\ELECTRONS::TOP.YAG_EDGETS.RESULTS:RMID')

        rpro = proNode.data()
        rrho = rhoNode.data()
        rtime = proNode.dim_of().data()

        goodTimes = rrho[0] > 0

        pro = rpro[:, goodTimes]
        rho = rrho[:, goodTimes]
        time = rtime[goodTimes]

        sm = ScalarMappable()
        rhoColor = sm.to_rgba(-rho)

        for i in range(rpro.shape[0]):
            ax.plot(time, pro[i], c=np.mean(rhoColor[i], axis=0))

        ax.set_ylabel('ne')
Example #27
0
    def multiTimeTrace(self, i, time, rhoFlat, data, ylabel, reverse=True):
        """Plot one trace per channel on ``self.axes[i]``, coloured by ``rhoFlat``.

        Channel ``j`` is coloured by ``rhoFlat[j]`` through 'gist_rainbow', or
        'gist_rainbow_r' when *reverse* is false.  *data* is read row-wise
        per channel when it has fewer rows than columns, otherwise
        column-wise.  *time* is either 1-D (shared by all channels) or 2-D
        and indexed the same way as *data*.
        """
        palette = 'gist_rainbow' if reverse else 'gist_rainbow_r'
        rhoColor = ScalarMappable(cmap=palette).to_rgba(rhoFlat)
        target = self.axes[i]

        for j in range(len(rhoFlat)):
            by_row = data.shape[0] < data.shape[1]
            if len(time.shape) > 1:
                t = time[j, :] if by_row else time[:, j]
            else:
                t = time
            trace = data[j, :] if by_row else data[:, j]
            target.plot(t, trace, c=rhoColor[j], linestyle='-', marker='.')

        target.set_ylabel(ylabel)
Example #28
0
    def plot_res(self, ax, do_label=True, tag_leafs=False, zlim=False,
                 cmap='jet_r'):
        """Fill every leaf of the tree on *ax*, coloured by its resolution ``dx``.

        :param ax: matplotlib axes to draw on
        :param do_label: annotate a resolution legend just right of *ax*
        :param tag_leafs: passed as ``key*tag_leafs`` for each leaf's label
            (the key itself when True, an empty value when False -- assuming
            string keys)
        :param zlim: optional ``(min, max)`` range for the log colour scale.
            When falsy (default), ``self.dx_min``/``self.dx_max`` are used.
        :param cmap: colormap name
        """
        from matplotlib.pyplot import get_cmap
        from matplotlib import patheffects
        from matplotlib.colors import LogNorm
        from matplotlib.cm import ScalarMappable

        # Create a log colour map using either zlim as given or max/min resolution.
        if zlim:
            vmin, vmax = zlim
        else:
            vmin, vmax = self.dx_min, self.dx_max
        cNorm = LogNorm(vmin=vmin, vmax=vmax, clip=True)
        cMap = ScalarMappable(cmap=get_cmap(cmap), norm=cNorm)

        # Distinct resolutions seen among the leaves, for the legend.
        dx_vals = set()

        for key in self.tree:
            if self[key].isLeaf:
                color = cMap.to_rgba(self[key].dx)
                dx_vals.add(self[key].dx)
                self[key].plot_res(ax, fc=color, label=key*tag_leafs)

        if do_label:
            ax.annotate('Resolution:', [1.02, 0.99], xycoords='axes fraction',
                        color='k', size='medium')
            for i, key in enumerate(sorted(dx_vals)):
                if key < 1:
                    # Show sub-unit resolutions as 1/n; round() guards against
                    # float error before %i truncates.
                    label = '1/%i' % round(key**-1)
                else:
                    label = '%i' % key
                ax.annotate('%s $R_{E}$' % label, [1.02, 0.87 - i * 0.1],
                            xycoords='axes fraction', color=cMap.to_rgba(key),
                            size='x-large', path_effects=[patheffects.withStroke(
                                linewidth=1, foreground='k')])
class ColorMapper(object):
    """Translate scalar values into ``'#rrggbb'`` strings using a colormap.

    Values are normalised linearly between *bottom_val* and *top_val* and
    looked up in the colormap named *color_palette_name*.
    """

    def __init__(self, bottom_val, top_val, color_palette_name):
        self.scalar_map = ScalarMappable(
            norm=mpl_Normalize(vmin=bottom_val, vmax=top_val),
            cmap=get_cmap(color_palette_name),
        )

    @staticmethod
    def rgb_to_hex(rgb):
        """Format an (r, g, b) triple of floats in [0, 1] as ``'#rrggbb'``.

        Each channel is multiplied by 255 and truncated to an int.
        """
        return '#%02x%02x%02x' % tuple(int(channel * 255) for channel in rgb)

    def color_from_val(self, val):
        """Return the hex colour for *val*; the alpha channel is discarded."""
        rgba = self.scalar_map.to_rgba(val)
        return self.rgb_to_hex(rgba[:3])
def _custom_colorbar(cmap, ncolors, labels, **kwargs):
    """Create a custom, discretized colorbar with correctly formatted/aligned labels.
    It was inspired mostly by the example provided in http://beneathdata.com/how-to/visualizing-my-location-history/

    :param cmap: the matplotlib colormap object you plan on using for your graph
    :param ncolors: (int) the number of discrete colors available
    :param labels: the list of labels for the colorbar. Must be the same length as ncolors.
    :param kwargs: forwarded to ``plt.colorbar``

    :return: custom colorbar
    :raises MapperError: if ``len(labels) != ncolors``
    """
    # '!=' instead of the Python-2-only '<>' operator.
    if ncolors != len(labels):
        raise MapperError("Number of colors is not compatible with the number of labels")

    norm = BoundaryNorm(range(0, ncolors), cmap.N)
    # Pass the BoundaryNorm so the colorbar is actually discretized.
    mappable = ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    mappable.set_clim(-0.5, ncolors + 0.5)
    colorbar = plt.colorbar(mappable, **kwargs)
    # One tick per label, at i + 0.5.  Tick and label counts must agree:
    # matplotlib >= 3.5 raises ValueError on a mismatch.
    colorbar.set_ticks(np.arange(ncolors) + 0.5)
    colorbar.set_ticklabels(labels)
    return colorbar
Example #31
0
def plot2d(x: Union[core.NeuronObject, core.Volume, np.ndarray,
                    List[Union[core.NeuronObject, np.ndarray, core.Volume]]],
           method: Union[Literal['2d'], Literal['3d'],
                         Literal['3d_complex']] = '3d',
           **kwargs) -> Tuple[mpl.figure.Figure, mpl.axes.Axes]:
    """Generate 2D plots of neurons and neuropils.

    The main advantage of this is that you can save plot as vector graphics.

    Important
    ---------
    This function uses matplotlib which "fakes" 3D as it has only very limited
    control over layering objects in 3D. Therefore neurites are not necessarily
    plotted in the right Z order. This becomes especially troublesome when
    plotting a complex scene with lots of neurons criss-crossing. See the
    ``method`` parameter for details. All methods use orthogonal projection.

    Parameters
    ----------
    x :                 TreeNeuron | MeshNeuron | NeuronList | Volume | Dotprops | np.ndarray
                        Objects to plot:

                        - multiple objects can be passed as list (see examples)
                        - numpy array of shape (n,3) is interpreted as points for
                          scatter plots
    method :            '2d' | '3d' (default) | '3d_complex'
                        Method used to generate plot. Comes in three flavours:

                        1. '2d' uses normal matplotlib. Neurons are plotted on
                           top of one another in the order they are passed to
                           the function. Use the ``view`` parameter (below) to
                           set the view (default = xy).
                        2. '3d' uses matplotlib's 3D axis. Here, matplotlib
                           decides the depth order (zorder) of plotting. Can
                           change perspective either interactively or by code
                           (see examples).
                        3. '3d_complex' same as 3d but each neuron segment is
                           added individually. This allows for more complex
                           zorders to be rendered correctly. Slows down
                           rendering though.
    soma :              bool, default=True
                        Plot soma if one exists. Size of the soma is determined
                        by the neuron's ``.soma_radius`` property which defaults
                        to the "radius" column for ``TreeNeurons``.
    connectors :        bool, default=False
                        Plot connectors.
    connectors_only :   boolean, default=False
                        Plot only connectors, not the neuron.
    cn_size :           int | float, default = 1
                        Size of connectors.
    linewidth :         int | float, default=.5
                        Width of neurites. Also accepts alias ``lw``.
    linestyle :         str, default='-'
                        Line style of neurites. Also accepts alias ``ls``.
    autoscale :         bool, default=True
                        If True, will scale the axes to fit the data.
    scalebar :          int | float | str | pint.Quantity, default=False
                        Adds scale bar. Provide integer, float or str to set
                        size of scalebar. Int|float are assumed to be in same
                        units as data. You can specify units as a string:
                        e.g. "1 um". For methods '3d' and '3d_complex', this
                        will create an axis object. ``None`` or ``False``
                        disables the scale bar.
    ax :                matplotlib ax, default=None
                        Pass an ax object if you want to plot on an existing
                        canvas. Must match ``method`` - i.e. 2D or 3D axis.
    figsize :           tuple, default=(8, 8)
                        Size of figure. For methods '3d' and '3d_complex' the
                        default is ``plt.figaspect(1) * 1.5`` instead.
    color :             None | str | tuple | list | dict, default=None
                        Use single str (e.g. ``'red'``) or ``(r, g, b)`` tuple
                        to give all neurons the same color. Use ``list`` of
                        colors to assign colors: ``['red', (1, 0, 1), ...]``.
                        Use ``dict`` to map colors to neuron IDs:
                        ``{id: (r, g, b), ...}``.
    palette :           str | array | list of arrays, default=None
                        Name of a matplotlib or seaborn palette. If ``color`` is
                        not specified will pick colors from this palette.
    color_by :          str | array | list of arrays, default = None
                        Can be the name of a column in the node table of
                        ``TreeNeurons`` or an array of (numerical or
                        categorical) values for each node. Numerical values will
                        be normalized. You can control the normalization by
                        passing a ``vmin`` and/or ``vmax`` parameter.
    shade_by :          str | array | list of arrays, default=None
                        Similar to ``color_by`` but will affect only the alpha
                        channel of the color. If ``shade_by='strahler'`` will
                        compute Strahler order if not already part of the node
                        table (TreeNeurons only). Numerical values will be
                        normalized. You can control the normalization by passing
                        a ``smin`` and/or ``smax`` parameter.
    norm_global :       bool, default=True
                        Passed as ``norm_global`` to ``vertex_colors`` when
                        using ``color_by`` or ``shade_by``.
    na :                str, default='raise'
                        Passed as ``na`` to ``vertex_colors`` when using
                        ``color_by`` or ``shade_by``; see that function for
                        accepted values.
    alpha :             float [0-1], default=.9
                        Alpha value for neurons. Overridden if alpha is provided
                        as fourth value in ``color`` (rgb*a*). You can override
                        alpha value for connectors by using ``cn_alpha``.
    clusters :          list, default=None
                        A list assigning a cluster to each neuron (e.g.
                        ``[0, 0, 0, 1, 1]``). Overrides ``color`` and uses
                        ``palette`` to generate colors according to clusters.
    depth_coloring :    bool, default=False
                        If True, will color encode depth (Z). Overrides
                        ``color``. Does not work with ``method = '3d_complex'``.
    depth_scale :       bool, default=True
                        If True and ``depth_coloring=True`` will plot a scale.
    cn_mesh_colors :    bool, default=False
                        If True, will use the neuron's color for its connectors.
    group_neurons :     bool, default=False
                        If True, neurons will be grouped. Works with SVG export
                        (not PDF). Does NOT work with ``method='3d_complex'``.
    scatter_kws :       dict, default={}
                        Parameters to be used when plotting points. Accepted
                        keywords are: ``size`` and ``color``.
    view :              tuple, default = ("x", "y")
                        Sets view for ``method='2d'``.
    orthogonal :        bool, default=True
                        Whether to use orthogonal or perspective view for
                        methods '3d' and '3d_complex'.
    volume_outlines :   bool | "both", default=True
                        If True will plot volume outline with no fill. Only
                        works with `method="2d"`.
    dps_scale_vec :     float
                        Scale vector for dotprops.
    rasterize :         bool, default=False
                        Neurons produce rather complex vector graphics which can
                        lead to large files when saving to SVG, PDF or PS. Use
                        this parameter to rasterize neurons and meshes/volumes
                        (but not axes or labels) to reduce file size.

    Examples
    --------
    .. plot::
       :context: close-figs

        >>> import navis
        >>> import matplotlib.pyplot as plt

        Plot list of neurons as simple 2d

        >>> nl = navis.example_neurons()
        >>> fig, ax = navis.plot2d(nl, method='2d', view=('x', '-y'))
        >>> plt.show() # doctest: +SKIP

    Add a volume

    .. plot::
       :context: close-figs

        >>> vol = navis.example_volume('LH')
        >>> fig, ax = navis.plot2d([nl, vol], method='2d', view=('x', '-y'))
        >>> plt.show() # doctest: +SKIP

    Change neuron colors

    .. plot::
       :context: close-figs

        >>> fig, ax = navis.plot2d(nl,
        ...                        method='2d',
        ...                        view=('x', '-y'),
        ...                        color=['r', 'g', 'b', 'm', 'c', 'y'])
        >>> plt.show() # doctest: +SKIP

    Plot in "fake" 3D

    .. plot::
       :context: close-figs

        >>> fig, ax = navis.plot2d(nl, method='3d')
        >>> plt.show() # doctest: +SKIP
        >>> # In an interactive window you can drag the plot to rotate

    Plot in "fake" 3D and change perspective

    .. plot::
       :context: close-figs

        >>> fig, ax = navis.plot2d(nl, method='3d')
        >>> # Change view to frontal (for example neurons)
        >>> ax.azim = ax.elev = 90
        >>> # Change view to lateral
        >>> ax.azim, ax.elev = 180, 180
        >>> ax.elev = 0
        >>> # Change view to top
        >>> ax.azim, ax.elev = 90, 180
        >>> # Tilted top view
        >>> ax.azim, ax.elev = -130, -150
        >>> # Move camera
        >>> ax.dist = 6
        >>> plt.show() # doctest: +SKIP

    Plot using depth-coloring

    .. plot::
       :context: close-figs

        >>> fig, ax = navis.plot2d(nl, method='3d', depth_coloring=True)
        >>> plt.show() # doctest: +SKIP

    To close all figures

    >>> plt.close('all')

    See the :ref:`plotting tutorial <plot_intro>` for more examples.


    Returns
    -------
    fig, ax :      matplotlib figure and axis object

    See Also
    --------
    :func:`navis.plot3d`
            Use this if you want interactive, perspectively correct renders
            and if you don't need vector graphics as outputs.
    :func:`navis.plot1d`
            A nifty way to visualise neurons in a single dimension.
    :func:`navis.plot_flat`
            Plot neurons as flat structures (e.g. dendrograms).

    """
    # Filter kwargs
    _ACCEPTED_KWARGS = [
        'soma', 'connectors', 'connectors_only', 'ax', 'color', 'colors', 'c',
        'view', 'scalebar', 'cn_mesh_colors', 'linewidth', 'cn_size',
        'cn_alpha', 'orthogonal', 'group_neurons', 'scatter_kws', 'figsize',
        'linestyle', 'rasterize', 'clusters', 'synapse_layout', 'alpha',
        'depth_coloring', 'autoscale', 'depth_scale', 'ls', 'lw',
        'volume_outlines', 'radius', 'dps_scale_vec', 'palette', 'color_by',
        'shade_by', 'vmin', 'vmax', 'smin', 'smax', 'norm_global',
        # 'na' is read below for color_by/shade_by, so it must be accepted
        'na'
    ]
    wrong_kwargs = [a for a in kwargs if a not in _ACCEPTED_KWARGS]
    if wrong_kwargs:
        raise KeyError(f'Unknown kwarg(s): {", ".join(wrong_kwargs)}. '
                       f'Currently accepted: {", ".join(_ACCEPTED_KWARGS)}')

    _METHOD_OPTIONS = ['2d', '3d', '3d_complex']
    if method not in _METHOD_OPTIONS:
        raise ValueError(f'Unknown method "{method}". Please use either: '
                         f'{",".join(_METHOD_OPTIONS)}')

    connectors = kwargs.get('connectors', False)
    connectors_only = kwargs.get('connectors_only', False)
    ax = kwargs.pop('ax', None)
    scalebar = kwargs.get('scalebar', None)

    # Depth coloring
    depth_coloring = kwargs.get('depth_coloring', False)
    depth_scale = kwargs.get('depth_scale', True)

    scatter_kws = kwargs.get('scatter_kws', {})
    autoscale = kwargs.get('autoscale', True)

    # Parse objects
    (neurons, volumes, points, _) = utils.parse_objects(x)

    # Generate colors. All three aliases are popped so none of them reaches the
    # plotting helpers; precedence is color > c > colors.
    colors = kwargs.pop('color', kwargs.pop('c', kwargs.pop('colors', None)))
    palette = kwargs.get('palette', None)
    color_by = kwargs.get('color_by', None)
    shade_by = kwargs.get('shade_by', None)

    # Generate the colormaps
    (neuron_cmap,
     volumes_cmap) = prepare_colormap(colors,
                                      neurons=neurons,
                                      volumes=volumes,
                                      palette=palette,
                                      clusters=kwargs.get('clusters', None),
                                      alpha=kwargs.get('alpha', None),
                                      color_range=1)

    if color_by is not None:
        # `palette` may be an array, so avoid plain truthiness (ambiguous for
        # numpy arrays).
        if palette is None or (isinstance(palette, str) and not palette):
            raise ValueError(
                'Must provide `palette` (e.g. "viridis") argument '
                'if using `color_by`')
        neuron_cmap = vertex_colors(neurons,
                                    by=color_by,
                                    use_alpha=False,
                                    palette=palette,
                                    norm_global=kwargs.get(
                                        'norm_global', True),
                                    vmin=kwargs.get('vmin', None),
                                    vmax=kwargs.get('vmax', None),
                                    na=kwargs.get('na', 'raise'),
                                    color_range=1)

    if shade_by is not None:
        alphamap = vertex_colors(
            neurons,
            by=shade_by,
            use_alpha=True,
            palette='viridis',  # palette is irrelevant here
            norm_global=kwargs.get('norm_global', True),
            vmin=kwargs.get('smin', None),
            vmax=kwargs.get('smax', None),
            na=kwargs.get('na', 'raise'),
            color_range=1)

        # Merge the per-vertex alpha from `alphamap` into each neuron's colors
        new_colormap = []
        for c, a in zip(neuron_cmap, alphamap):
            # Expand a single color to one row per vertex
            if not (isinstance(c, np.ndarray) and c.ndim == 2):
                c = np.tile(c, (a.shape[0], 1))

            if c.shape[1] == 4:
                c[:, 3] = a[:, 3]
            else:
                c = np.insert(c, 3, a[:, 3], axis=1)

            new_colormap.append(c)
        neuron_cmap = new_colormap

    # Set axis projection (note: this patches mplot3d's module-level function)
    if method in ['3d', '3d_complex']:
        if kwargs.get('orthogonal', True):
            proj3d.persp_transformation = _orthogonal_proj
        else:
            proj3d.persp_transformation = _perspective_proj

    # Generate axes
    if not ax:
        if method == '2d':
            fig, ax = plt.subplots(figsize=kwargs.get('figsize', (8, 8)))
            ax.set_aspect('equal')
        elif method in ['3d', '3d_complex']:
            fig = plt.figure(figsize=kwargs.get('figsize',
                                                plt.figaspect(1) * 1.5))
            ax = fig.add_subplot(111, projection='3d')

            # This sets front view
            ax.azim = -90
            ax.elev = 0
            ax.dist = 7
            # Disallowed for 3D in matplotlib 3.1.0
            # ax.set_aspect('equal')
        # Make background transparent (nicer for dark themes)
        fig.patch.set_alpha(0)
        ax.patch.set_alpha(0)
    # Check if correct axis were provided
    else:
        if not isinstance(ax, mpl.axes.Axes):
            raise TypeError('Ax must be of type "mpl.axes.Axes", '
                            f'not "{type(ax)}"')
        fig = ax.get_figure()
        if method in ['3d', '3d_complex']:
            if ax.name != '3d':
                raise TypeError('Axis must be 3d.')
        elif method == '2d':
            if ax.name == '3d':
                raise TypeError('Axis must be 2d.')

    ax.had_data = ax.has_data()

    # Prepare some stuff for depth coloring
    if depth_coloring and not neurons.empty:
        if method == '3d_complex':
            raise Exception(
                f'Depth coloring unavailable for method "{method}"')
        elif method == '2d':
            bbox = neurons.bbox
            # The depth axis is whichever of x/y/z is not part of the view
            xy = [
                v.replace('-', '').replace('+', '')
                for v in kwargs.get('view', ('x', 'y'))
            ]
            z_ix = [
                v[1] for v in [('x', 0), ('y', 1), ('z', 2)] if v[0] not in xy
            ]

            # Add to kwargs so the plotting helpers receive the normaliser
            kwargs['norm'] = plt.Normalize(vmin=bbox[z_ix, 0],
                                           vmax=bbox[z_ix, 1])

    # Plot volumes first
    if volumes:
        for i, v in enumerate(volumes):
            _ = _plot_volume(v, volumes_cmap[i], method, ax, **kwargs)

    # Create lines from segments
    visuals = {}
    for i, neuron in enumerate(
            config.tqdm(neurons,
                        desc='Plot neurons',
                        leave=False,
                        # `or`, not `|`: `|` binds tighter than `<` and would
                        # evaluate `(pbar_hide | len(neurons)) < 2`.
                        disable=config.pbar_hide or len(neurons) < 2)):
        if not connectors_only:
            if isinstance(neuron, core.TreeNeuron) and kwargs.get(
                    'radius', False):
                _neuron = conversion.tree2meshneuron(neuron)
                _neuron.connectors = neuron.connectors
                neuron = _neuron

            if isinstance(neuron, core.TreeNeuron) and neuron.nodes.empty:
                logger.warning(f'Skipping TreeNeuron w/o nodes: {neuron.id}')
            elif isinstance(neuron,
                            core.MeshNeuron) and neuron.faces.size == 0:
                logger.warning(f'Skipping MeshNeuron w/o faces: {neuron.id}')
            elif isinstance(neuron, core.Dotprops) and neuron.points.size == 0:
                logger.warning(f'Skipping Dotprops w/o points: {neuron.id}')
            elif isinstance(neuron, core.TreeNeuron):
                lc, sc = _plot_skeleton(neuron, neuron_cmap[i], method, ax,
                                        **kwargs)
                # Keep track of visuals related to this neuron
                visuals[neuron] = {'skeleton': lc, 'somata': sc}
            elif isinstance(neuron, core.MeshNeuron):
                m = _plot_mesh(neuron, neuron_cmap[i], method, ax, **kwargs)
                visuals[neuron] = {'mesh': m}
            elif isinstance(neuron, core.Dotprops):
                dp = _plot_dotprops(neuron, neuron_cmap[i], method, ax,
                                    **kwargs)
                visuals[neuron] = {'dotprop': dp}
            elif isinstance(neuron, core.VoxelNeuron):
                dp = _plot_voxels(neuron, neuron_cmap[i], method, ax, kwargs,
                                  **scatter_kws)
                visuals[neuron] = {'dotprop': dp}
            else:
                raise TypeError(
                    f"Don't know how to plot neuron of type '{type(neuron)}' ")

        if (connectors or connectors_only) and neuron.has_connectors:
            _ = _plot_connectors(neuron, neuron_cmap[i], method, ax, **kwargs)

    for p in points:
        _ = _plot_scatter(p, method, ax, kwargs, **scatter_kws)

    if autoscale:
        if method == '2d':
            ax.autoscale(tight=True)
        elif method in ['3d', '3d_complex']:
            # Make sure data lims are set correctly
            update_axes3d_bounds(ax)
            # Resize to have equal aspect
            set_axes3d_equal(ax)

    # `False` is the documented "off" value, so skip it as well as None
    if scalebar is not None and scalebar is not False:
        _ = _add_scalebar(scalebar, neurons, method, ax)

    def set_depth():
        """Set depth information for neurons according to camera position."""
        # Get projected coordinates
        proj_co = mpl_toolkits.mplot3d.proj3d.proj_points(
            all_co, ax.get_proj())

        # Get min and max of z coordinates
        z_min, z_max = min(proj_co[:, 2]), max(proj_co[:, 2])

        # Generate a new normaliser
        norm = plt.Normalize(vmin=z_min, vmax=z_max)

        # Go over all neurons and update Z information
        for neuron in visuals:
            # Get this neuron's collection and coordinates
            if 'skeleton' in visuals[neuron]:
                c = visuals[neuron]['skeleton']
                this_co = c._segments3d[:, 0, :]
            elif 'mesh' in visuals[neuron]:
                c = visuals[neuron]['mesh']
                # Note that we only get every third position -> that's because
                # these vectors actually represent faces, i.e. each vertex
                this_co = c._vec.T[::3, [0, 1, 2]]
            else:
                raise ValueError(
                    f'Neither mesh nor skeleton found for neuron {neuron.id}')

            # Get projected coordinates
            this_proj = mpl_toolkits.mplot3d.proj3d.proj_points(
                this_co, ax.get_proj())

            # Normalise z coordinates
            ns = norm(this_proj[:, 2]).data

            # Set array
            c.set_array(ns)

            # No need for normaliser - already happened
            c.set_norm(None)

            if (isinstance(neuron, core.TreeNeuron) and
                    getattr(neuron, 'soma', None) is not None):
                # Get depth of soma(s)
                soma = utils.make_iterable(neuron.soma)
                soma_co = neuron.nodes.set_index('node_id').loc[soma][[
                    'x', 'y', 'z'
                ]].values
                soma_proj = mpl_toolkits.mplot3d.proj3d.proj_points(
                    soma_co, ax.get_proj())
                soma_cs = norm(soma_proj[:, 2]).data

                # Set soma color
                for cs, s in zip(soma_cs, visuals[neuron]['somata']):
                    s.set_color(cmap(cs))

    def _on_draw(event):
        """Recompute depth colors whenever the canvas is redrawn."""
        set_depth()

    if depth_coloring:
        cmap = mpl.cm.jet
        # 'norm' is only set above when there were neurons to measure
        if method == '2d' and depth_scale and 'norm' in kwargs:
            sm = ScalarMappable(norm=kwargs['norm'], cmap=cmap)
            fig.colorbar(sm, ax=ax, fraction=.075, shrink=.5, label='Depth')
        elif method == '3d':
            # Collect all coordinates
            all_co = []
            for n in visuals:
                if 'skeleton' in visuals[n]:
                    all_co.append(visuals[n]['skeleton']._segments3d[:, 0, :])
                if 'mesh' in visuals[n]:
                    all_co.append(visuals[n]['mesh']._vec.T[:, [0, 1, 2]])

            # np.concatenate raises on an empty list (e.g. no neurons plotted)
            if all_co:
                all_co = np.concatenate(all_co, axis=0)
                fig.canvas.mpl_connect('draw_event', _on_draw)
                set_depth()

    # Turn off the axes we actually drew on (plt.axis would hit the current
    # axes, which need not be a user-supplied `ax`).
    ax.axis('off')

    return fig, ax
Exemple #32
0
    def initialize_chart(self):
        """Draw a scatter plot of the vis' x/y channels and record equivalent code.

        The plot is drawn onto ``self.ax``. Matching standalone matplotlib code
        is appended to ``self.code``.

        If the vis has exactly one color channel, a new figure is created via
        ``matplotlib_setup``:

        * quantitative color attributes are drawn with the "Blues" colormap,
          normalised on ``[0, max(values)]``, with a matching colorbar;
        * any other color attribute gets a categorical palette and a legend
          showing at most the first 16 distinct values (each label is
          truncated to 26 characters plus "...").

        Without a color channel, the existing ``self.ax`` is reused. The emitted
        code assumes a (4.5, 4) figure in that case.
        """
        x_attr = self.vis.get_attr_by_channel("x")[0]
        y_attr = self.vis.get_attr_by_channel("y")[0]

        x_attr_abv = x_attr.attribute
        y_attr_abv = y_attr.attribute

        # Abbreviate long axis titles: keep the first 15 and last 10 characters
        if len(x_attr.attribute) > 25:
            x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:]
        if len(y_attr.attribute) > 25:
            y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:]

        df = self.data.dropna()

        x_pts = df[x_attr.attribute]
        y_pts = df[y_attr.attribute]

        set_fig_code = ""
        plot_code = ""

        color_attr = self.vis.get_attr_by_channel("color")
        if len(color_attr) == 1:
            color_attr_name = color_attr[0].attribute
            color_attr_type = color_attr[0].data_type
            colors = df[color_attr_name].values
            plot_code += f"colors = df['{color_attr_name}'].values\n"
            if color_attr_type == "quantitative":
                self.fig, self.ax = matplotlib_setup(7, 5)
                set_fig_code = "fig, ax = plt.subplots(figsize=(7, 5))\n"
                max_color = max(colors)
                # plt.get_cmap works on all matplotlib versions (plt.cm.get_cmap
                # was removed in 3.9).
                my_cmap = plt.get_cmap("Blues")
                # Color by the actual values with the same normalisation as
                # the colorbar, so the points and the colorbar agree.
                self.ax.scatter(
                    x_pts, y_pts, c=colors, cmap=my_cmap,
                    norm=plt.Normalize(0, max_color), alpha=0.5,
                )
                plot_code += "my_cmap = plt.get_cmap('Blues')\n"
                plot_code += (
                    "ax.scatter(x_pts, y_pts, c=colors, cmap=my_cmap, "
                    f"norm=plt.Normalize(0, {max_color}), alpha=0.5)\n"
                )
                sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0, max_color))
                sm.set_array([])

                # A bare ScalarMappable has no Axes, so name the one the
                # colorbar should take space from.
                cbar = self.fig.colorbar(sm, ax=self.ax, label=color_attr_name)
                cbar.outline.set_linewidth(0)
                plot_code += f"""sm = ScalarMappable(
                    cmap=my_cmap, 
                    norm=plt.Normalize(0, {max_color}))\n"""

                plot_code += f"cbar = fig.colorbar(sm, ax=ax, label='{color_attr_name}')\n"
                plot_code += "cbar.outline.set_linewidth(0)\n"
            else:
                unique = list(set(colors))
                # Map each value to its position in `unique` (dict lookup keeps
                # this O(n) instead of calling list.index once per point).
                index_of = {value: i for i, value in enumerate(unique)}
                vals = [index_of[value] for value in colors]

                # Only the first 16 categories get a legend entry
                if len(unique) >= 16:
                    unique = unique[:16]

                maxlen = 0
                for i in range(len(unique)):
                    unique[i] = str(unique[i])
                    if len(unique[i]) > 26:
                        unique[i] = unique[i][:26] + "..."
                    if len(unique[i]) > maxlen:
                        maxlen = len(unique[i])
                # Widen the figure when legend labels are long
                if maxlen > 20:
                    self.fig, self.ax = matplotlib_setup(9, 5)
                    set_fig_code = "fig, ax = plt.subplots(figsize=(9, 5))\n"
                else:
                    self.fig, self.ax = matplotlib_setup(7, 5)
                    set_fig_code = "fig, ax = plt.subplots(figsize=(7, 5))\n"

                cmap = "Set1"
                if len(unique) > 9:
                    cmap = "tab20c"
                scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap=cmap)
                # Quote the colormap name: emitting `cmap=Set1` would be a NameError
                plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap='{cmap}')\n"

                leg = self.ax.legend(
                    handles=scatter.legend_elements(num=range(0, len(unique)))[0],
                    labels=unique,
                    title=color_attr_name,
                    markerscale=2.0,
                    bbox_to_anchor=(1.05, 1),
                    loc="upper left",
                    ncol=1,
                    frameon=False,
                    fontsize="13",
                )
                scatter.set_alpha(0.5)
                plot_code += f"""ax.legend(
                    handles=scatter.legend_elements(num=range(0, len({unique})))[0],
                    labels={unique},
                    title='{color_attr_name}', 
                    markerscale=2.,
                    bbox_to_anchor=(1.05, 1), 
                    loc='upper left', 
                    ncol=1, 
                    frameon=False,
                    fontsize='13')\n"""
                plot_code += "scatter.set_alpha(0.5)\n"
        else:
            set_fig_code = "fig, ax = plt.subplots(figsize=(4.5, 4))\n"
            self.ax.scatter(x_pts, y_pts, alpha=0.5)
            plot_code += "ax.scatter(x_pts, y_pts, alpha=0.5)\n"
        self.ax.set_xlabel(x_attr_abv, fontsize="15")
        self.ax.set_ylabel(y_attr_abv, fontsize="15")

        self.code += "import numpy as np\n"
        self.code += "from math import nan\n"
        self.code += "from matplotlib.cm import ScalarMappable\n"

        self.code += set_fig_code
        self.code += f"x_pts = df['{x_attr.attribute}']\n"
        self.code += f"y_pts = df['{y_attr.attribute}']\n"

        self.code += plot_code
        self.code += f"ax.set_xlabel('{x_attr_abv}', fontsize='15')\n"
        self.code += f"ax.set_ylabel('{y_attr_abv}', fontsize='15')\n"
Exemple #33
0
def plot_basemap(lons, lats, size, color, labels=None,
                 projection='cyl', resolution='l', continent_fill_color='0.8',
                 water_fill_color='1.0', colormap=None, colorbar=None,
                 marker="o", title=None, colorbar_ticklabel_format=None,
                 show=True, **kwargs):  # @UnusedVariable
    """
    Creates a basemap plot with a data point scatter plot.

    :type lons: list/tuple of floats
    :param lons: Longitudes of the data points.
    :type lats: list/tuple of floats
    :param lats: Latitudes of the data points.
    :type size: float or list/tuple of floats
    :param size: Size of the individual points in the scatter plot.
    :type color: float or list/tuple of
        floats/:class:`~obspy.core.utcdatetime.UTCDateTime`
    :param color: Color information of the individual data points. Can be
    :type labels: list/tuple of str
    :param labels: Annotations for the individual data points.
    :type projection: str, optional
    :param projection: The map projection. Currently supported are
        * ``"cyl"`` (Will plot the whole world.)
        * ``"ortho"`` (Will center around the mean lat/long.)
        * ``"local"`` (Will plot around local events)
        Defaults to "cyl"
    :type resolution: str, optional
    :param resolution: Resolution of the boundary database to use. Will be
        based directly to the basemap module. Possible values are
        * ``"c"`` (crude)
        * ``"l"`` (low)
        * ``"i"`` (intermediate)
        * ``"h"`` (high)
        * ``"f"`` (full)
        Defaults to ``"l"``
    :type continent_fill_color: Valid matplotlib color, optional
    :param continent_fill_color:  Color of the continents. Defaults to
        ``"0.9"`` which is a light gray.
    :type water_fill_color: Valid matplotlib color, optional
    :param water_fill_color: Color of all water bodies.
        Defaults to ``"white"``.
    :type colormap: str, any matplotlib colormap, optional
    :param colormap: The colormap for color-coding the events.
        The event with the smallest property will have the
        color of one end of the colormap and the event with the biggest
        property the color of the other end with all other events in
        between.
        Defaults to None which will use the default colormap for the date
        encoding and a colormap going from green over yellow to red for the
        depth encoding.
    :type colorbar: bool, optional
    :param colorbar: When left `None`, a colorbar is plotted if more than one
        object is plotted. Using `True`/`False` the colorbar can be forced
        on/off.
    :type title: str
    :param title: Title above plot.
    :type colorbar_ticklabel_format: str or function or
        subclass of :class:`matplotlib.ticker.Formatter`
    :param colorbar_ticklabel_format: Format string or Formatter used to format
        colorbar tick labels.
    :type show: bool
    :param show: Whether to show the figure after plotting or not. Can be used
        to do further customization of the plot before showing it.
    """
    min_color = min(color)
    max_color = max(color)

    if isinstance(color[0], (datetime.datetime, UTCDateTime)):
        datetimeplot = True
        color = [date2num(t) for t in color]
    else:
        datetimeplot = False

    scal_map = ScalarMappable(norm=Normalize(min_color, max_color),
                              cmap=colormap)
    scal_map.set_array(np.linspace(0, 1, 1))

    fig = plt.figure()
    # The colorbar should only be plotted if more then one event is
    # present.

    if colorbar is not None:
        show_colorbar = colorbar
    else:
        if len(lons) > 1 and hasattr(color, "__len__") and \
                not isinstance(color, (str, native_str)):
            show_colorbar = True
        else:
            show_colorbar = False

    if projection == "local":
        ax_x0, ax_width = 0.10, 0.80
    else:
        ax_x0, ax_width = 0.05, 0.90

    if show_colorbar:
        map_ax = fig.add_axes([ax_x0, 0.13, ax_width, 0.77])
        cm_ax = fig.add_axes([ax_x0, 0.05, ax_width, 0.05])
        plt.sca(map_ax)
    else:
        ax_y0, ax_height = 0.05, 0.85
        if projection == "local":
            ax_y0 += 0.05
            ax_height -= 0.05
        map_ax = fig.add_axes([ax_x0, ax_y0, ax_width, ax_height])

    if projection == 'cyl':
        bmap = Basemap(resolution=resolution)
    elif projection == 'ortho':
        bmap = Basemap(projection='ortho', resolution=resolution,
                       area_thresh=1000.0, lat_0=np.mean(lats),
                       lon_0=np.mean(lons))
    elif projection == 'local':
        if min(lons) < -150 and max(lons) > 150:
            max_lons = max(np.array(lons) % 360)
            min_lons = min(np.array(lons) % 360)
        else:
            max_lons = max(lons)
            min_lons = min(lons)
        lat_0 = max(lats) / 2. + min(lats) / 2.
        lon_0 = max_lons / 2. + min_lons / 2.
        if lon_0 > 180:
            lon_0 -= 360
        deg2m_lat = 2 * np.pi * 6371 * 1000 / 360
        deg2m_lon = deg2m_lat * np.cos(lat_0 / 180 * np.pi)
        if len(lats) > 1:
            height = (max(lats) - min(lats)) * deg2m_lat
            width = (max_lons - min_lons) * deg2m_lon
            margin = 0.2 * (width + height)
            height += margin
            width += margin
        else:
            height = 2.0 * deg2m_lat
            width = 5.0 * deg2m_lon
        # do intelligent aspect calculation for local projection
        # adjust to figure dimensions
        w, h = fig.get_size_inches()
        aspect = w / h
        if show_colorbar:
            aspect *= 1.2
        if width / height < aspect:
            width = height * aspect
        else:
            height = width / aspect

        bmap = Basemap(projection='aeqd', resolution=resolution,
                       area_thresh=1000.0, lat_0=lat_0, lon_0=lon_0,
                       width=width, height=height)
        # not most elegant way to calculate some round lats/lons

        def linspace2(val1, val2, N):
            """
            returns around N 'nice' values between val1 and val2
            """
            dval = val2 - val1
            round_pos = int(round(-np.log10(1. * dval / N)))
            # Fake negative rounding as not supported by future as of now.
            if round_pos < 0:
                factor = 10 ** (abs(round_pos))
                delta = round(2. * dval / N / factor) * factor / 2
            else:
                delta = round(2. * dval / N, round_pos) / 2
            new_val1 = np.ceil(val1 / delta) * delta
            new_val2 = np.floor(val2 / delta) * delta
            N = (new_val2 - new_val1) / delta + 1
            return np.linspace(new_val1, new_val2, N)

        N1 = int(np.ceil(height / max(width, height) * 8))
        N2 = int(np.ceil(width / max(width, height) * 8))
        bmap.drawparallels(linspace2(lat_0 - height / 2 / deg2m_lat,
                                     lat_0 + height / 2 / deg2m_lat, N1),
                           labels=[0, 1, 1, 0])
        if min(lons) < -150 and max(lons) > 150:
            lon_0 %= 360
        meridians = linspace2(lon_0 - width / 2 / deg2m_lon,
                              lon_0 + width / 2 / deg2m_lon, N2)
        meridians[meridians > 180] -= 360
        bmap.drawmeridians(meridians, labels=[1, 0, 0, 1])
    else:
        msg = "Projection '%s' not supported." % projection
        raise ValueError(msg)

    # draw coast lines, country boundaries, fill continents.
    plt.gca().set_axis_bgcolor(water_fill_color)
    bmap.drawcoastlines(color="0.4")
    bmap.drawcountries(color="0.75")
    bmap.fillcontinents(color=continent_fill_color,
                        lake_color=water_fill_color)
    # draw the edge of the bmap projection region (the projection limb)
    bmap.drawmapboundary(fill_color=water_fill_color)
    # draw lat/lon grid lines every 30 degrees.
    bmap.drawmeridians(np.arange(-180, 180, 30))
    bmap.drawparallels(np.arange(-90, 90, 30))

    # compute the native bmap projection coordinates for events.
    x, y = bmap(lons, lats)
    # plot labels
    if labels:
        if 100 > len(lons) > 1:
            for name, xpt, ypt, _colorpt in zip(labels, x, y, color):
                # Check if the point can actually be seen with the current bmap
                # projection. The bmap object will set the coordinates to very
                # large values if it cannot project a point.
                if xpt > 1e25:
                    continue
                plt.text(xpt, ypt, name, weight="heavy",
                         color="k", zorder=100, **path_effect_kwargs)
        elif len(lons) == 1:
            plt.text(x[0], y[0], labels[0], weight="heavy", color="k",
                     **path_effect_kwargs)

    scatter = bmap.scatter(x, y, marker=marker, s=size, c=color,
                           zorder=10, cmap=colormap)

    if title:
        plt.suptitle(title)

    # Only show the colorbar for more than one event.
    if show_colorbar:
        if colorbar_ticklabel_format is not None:
            if isinstance(colorbar_ticklabel_format, (str, native_str)):
                formatter = FormatStrFormatter(colorbar_ticklabel_format)
            elif hasattr(colorbar_ticklabel_format, '__call__'):
                formatter = FuncFormatter(colorbar_ticklabel_format)
            elif isinstance(colorbar_ticklabel_format, Formatter):
                formatter = colorbar_ticklabel_format
            locator = MaxNLocator(5)
        else:
            if datetimeplot:
                locator = AutoDateLocator()
                formatter = AutoDateFormatter(locator)
                formatter.scaled[1 / (24. * 60.)] = '%H:%M:%S'
            else:
                locator = None
                formatter = None
        cb = Colorbar(cm_ax, scatter, cmap=colormap,
                      orientation='horizontal',
                      ticks=locator,
                      format=formatter)
        #              format=formatter)
        #              ticks=mpl.ticker.MaxNLocator(4))
        cb.update_ticks()

    if show:
        plt.show()

    return fig
def plot_embedding_v2(X,
                      values,
                      classes=None,
                      method='tSNE',
                      cmap='tab20',
                      figsize=(8, 8),
                      markersize=15,
                      dpi=600,
                      marker=None,
                      return_emb=False,
                      save=False,
                      save_emb=False,
                      show_legend=False,
                      show_axis_label=True,
                      **legend_params):
    """Scatter-plot a 2-D embedding of ``X`` coloured by ``values``.

    If ``X`` does not already have two columns it is reduced to 2-D with the
    requested ``method``. Points are coloured with the diverging ``'RdBu_r'``
    colormap on a range symmetric around zero (``±max(|values|)``), so zero
    maps to the neutral centre colour.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Data, or an existing 2-D embedding, to plot.
    values : array-like, shape (n_samples,)
        Scalar used to colour each point. When ``marker`` is given it must
        also hold one value per marker row (marker rows come last).
    classes : unused
        Kept for backward compatibility.
    method : {'tSNE', 'PCA', 'UMAP'}
        Reduction applied when ``X`` is not already 2-D.
    cmap : unused
        Kept for backward compatibility; the plot always uses ``'RdBu_r'``.
    figsize : unused
        Kept for backward compatibility; the current figure is drawn into.
    markersize : float
        Marker area passed to ``plt.scatter`` as ``s``.
    dpi : int
        Resolution used when saving.
    marker : array-like, optional
        Extra rows appended to ``X`` before the reduction.
    return_emb : bool
        If true, return the 2-D coordinates that were plotted.
    save : str or False
        PNG output path, or a falsy value to skip saving.
    save_emb : unused
        Kept for backward compatibility.
    show_legend : bool
        Draw a legend; ``legend_params`` override the defaults.
    show_axis_label : bool
        Label both axes with ``method``.

    Returns
    -------
    numpy.ndarray or None
        The (n_samples, 2) embedding if ``return_emb`` is true, else None.

    Raises
    ------
    ValueError
        If ``values`` has fewer entries than there are points, or if
        ``method`` is unknown while a reduction is needed.
    """
    if marker is not None:
        X = np.concatenate([X, marker], axis=0)
    n_points = X.shape[0]

    values = np.asarray(values)
    # Validate up front instead of failing with an IndexError halfway through
    # plotting (e.g. marker rows appended without matching values).
    if len(values) < n_points:
        raise ValueError('values has %d entries but %d points would be plotted'
                         % (len(values), n_points))

    if X.shape[1] != 2:
        if method == 'tSNE':
            from sklearn.manifold import TSNE
            X = TSNE(n_components=2, random_state=124).fit_transform(X)
        elif method == 'PCA':
            from sklearn.decomposition import PCA
            X = PCA(n_components=2, random_state=124).fit_transform(X)
        elif method == 'UMAP':
            from umap import UMAP
            X = UMAP(n_neighbors=15, min_dist=0.1,
                     metric='correlation').fit_transform(X)
        else:
            # Previously an unknown method silently plotted two raw features.
            raise ValueError('Unknown embedding method %r' % (method,))

    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    values = values[:n_points]
    # Symmetric range around zero based on the largest *absolute* value, so
    # strongly negative values are not clipped and all-negative input does
    # not produce vmin > vmax.
    limit = np.nanmax(np.abs(values)) if n_points else 0.0
    if not limit > 0:
        limit = 1.0  # all zero / all NaN / empty: avoid a degenerate range
    sm = ScalarMappable(norm=Normalize(vmin=-limit, vmax=limit),
                        cmap='RdBu_r')

    # One vectorised scatter call instead of one artist per point.
    plt.scatter(X[:, 0], X[:, 1], s=markersize, c=sm.to_rgba(values))

    legend_params_ = {
        'loc': 'center left',
        'bbox_to_anchor': (1.0, 0.45),
        'fontsize': 20,
        'ncol': 1,
        'frameon': False,
        'markerscale': 1.5
    }
    legend_params_.update(**legend_params)
    if show_legend:
        plt.legend(**legend_params_)
    sns.despine(offset=10, trim=True)
    if show_axis_label:
        plt.xlabel(method + ' dim 1', fontsize=12)
        plt.ylabel(method + ' dim 2', fontsize=12)

    if save:
        plt.savefig(save, format='png', bbox_inches='tight', dpi=dpi)

    if return_emb:
        return X
Exemple #35
0
def plot_depth(image_key, pixel_age, flag_map,  # validity_map,
               depth_map_true, depth_map_pred, variance_map,
               image_cmap="gray", depth_cmap='RdBu'):
    """Show a 2x4 diagnostic grid comparing predicted and true depth.

    Top row: the keyframe, the per-pixel age (with colorbar) and the flag
    map (with one legend entry per ``FLAG`` member). If no pixel is flagged
    ``FLAG.SUCCESS``, the figure is shown with only those panels and the
    function returns.

    Bottom row: ground-truth depth, the predicted depth and the absolute
    error (both drawn as dots over the keyframe at the successful pixels),
    and the variance map. Truth, prediction and error share one colour
    scale, from 0 to the 98th percentile of the true depths together with
    the successful predictions. The variance panel is scaled to its own
    98th percentile.
    """
    fig = plt.figure()

    def panel(position, title):
        # Add subplot ``position`` of the 2x4 grid and give it a title.
        axes = fig.add_subplot(2, 4, position)
        axes.set_title(title)
        return axes

    panel(1, "keyframe").imshow(image_key, cmap=image_cmap)

    age_ax = panel(2, "pixel age")
    plot_with_bar(age_ax, age_ax.imshow(pixel_age, cmap=image_cmap))

    flag_ax = panel(3, "flag map")
    flag_ax.imshow(flag_to_color_map(flag_map))
    legend_handles = [Patch("black", flag_to_rgb(flag), label=flag.name)
                      for flag in FLAG]
    flag_ax.legend(handles=legend_handles, loc='center left',
                   bbox_to_anchor=(1.05, 0.5))

    success = flag_map == FLAG.SUCCESS

    if not success.sum():  # no success pixel
        plt.show()
        return

    predicted = depth_map_pred[success]
    observed = depth_map_true[success]
    abs_error = np.abs(predicted - observed)

    pixel_coords = image_coordinates(depth_map_pred.shape)[success.flatten()]

    upper = np.percentile(
        np.concatenate((depth_map_true.flatten(), predicted)), 98
    )
    shared_norm = Normalize(vmin=0.0, vmax=upper)
    to_rgba = ScalarMappable(norm=shared_norm, cmap=depth_cmap).to_rgba

    truth_ax = panel(5, "ground truth depth")
    truth_ax.axis("off")
    plot_with_bar(truth_ax, truth_ax.imshow(depth_map_true, norm=shared_norm,
                                            cmap=depth_cmap))

    rows, cols = image_key.shape[0:2]

    def overlay(axes, depths):
        # Keyframe underneath, coloured dots at the successful pixels on top,
        # with the view pinned to the image extent (y axis pointing down).
        axes.imshow(image_key, cmap=image_cmap)
        axes.scatter(pixel_coords[:, 0], pixel_coords[:, 1], s=0.5,
                     c=to_rgba(depths))
        axes.set_xlim(0, cols)
        axes.set_ylim(rows, 0)

    pred_ax = panel(6, "predicted depth map")
    pred_ax.axis("off")
    overlay(pred_ax, predicted)

    overlay(panel(7, "error = abs(pred - true)"), abs_error)

    var_ax = panel(8, "variance map")
    var_norm = Normalize(vmin=0.0, vmax=np.percentile(variance_map.flatten(), 98))
    plot_with_bar(var_ax, var_ax.imshow(variance_map, norm=var_norm,
                                        cmap=depth_cmap))

    plt.show()
Exemple #36
0
def contam(dval):
    """
    Plot the contamination against the stellar TESS magnitudes.

    For every cadence in ``dval.cadences`` this queries the targets processed
    with aperture photometry and scatter-plots their contamination against
    TESS magnitude. Values above 1 are drawn at 1.1 and NaN values at 1.2.
    The magnitude-dependent contamination limit is overlaid, and the figure
    is saved to ``dval.outfolder``. It then assigns validation flags:

    * contamination >= 1 (after the substitutions above) gets
      ``DatavalQualityFlags.InvalidContamination``;
    * contamination above the magnitude limit but below 1 gets
      ``DatavalQualityFlags.ContaminationHigh``.

    When ``dval.color_by_sector`` is set and more than one sector is present,
    the points are coloured by sector.

    Parameters
    ----------
    dval : object
        Data-validation object providing ``cadences``, ``search_database``,
        ``color_by_sector``, ``tmag_limits``, ``outfolder``, ``show`` and
        ``update_dataval``.

    .. codeauthor:: Mikkel N. Lund <*****@*****.**>
    .. codeauthor:: Rasmus Handberg <*****@*****.**>
    """

    logger = logging.getLogger('dataval')
    logger.info('Plotting Contamination vs. Magnitude...')

    # Piecewise-linear contamination limit as a function of TESS magnitude.
    xmax = np.arange(0, 21, 1)
    ymax = np.array([
        0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.12, 0.2, 0.3, 0.45, 0.6, 0.7, 0.8, 0.9,
        0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9
    ])
    cont_vs_mag = InterpolatedUnivariateSpline(xmax, ymax, k=1, ext=3)

    def subset_colors(rgba, idx):
        # A per-point colour array must be sliced exactly like the points it
        # colours. A single colour spec (e.g. 'k') applies to every point.
        if isinstance(rgba, np.ndarray):
            return rgba[idx]
        return rgba

    for cadence in dval.cadences:

        fig, ax = plt.subplots()
        fig.subplots_adjust(left=0.145,
                            wspace=0.3,
                            top=0.945,
                            bottom=0.145,
                            right=0.975)

        # Search database for all targets processed with aperture photometry:
        star_vals = dval.search_database(
            select=[
                'todolist.priority', 'todolist.sector', 'todolist.tmag',
                'contamination'
            ],
            search=["method_used='aperture'", f'cadence={cadence:d}'])

        rgba_color = 'k'
        if dval.color_by_sector:
            sec = np.array([star['sector'] for star in star_vals],
                           dtype='int32')
            # Map each sector to its rank 1..n among the sectors present, so
            # colours span the whole Normalize(1, n) range. Raw sector
            # numbers mostly fall outside that range and would be clipped.
            unique_sectors, sector_rank = np.unique(sec, return_inverse=True)
            if len(unique_sectors) > 1:
                norm = colors.Normalize(vmin=1, vmax=len(unique_sectors))
                scalarMap = ScalarMappable(norm=norm,
                                           cmap=plt.get_cmap('Set1'))
                rgba_color = scalarMap.to_rgba(sector_rank + 1)

        pri = np.array([star['priority'] for star in star_vals], dtype='int64')
        tmags = np.array([star['tmag'] for star in star_vals], dtype='float64')
        contamination = np.array([star['contamination'] for star in star_vals],
                                 dtype='float64')

        # Indices for plotting
        idx_invalid = np.isnan(contamination)
        with np.errstate(invalid='ignore'):  # We know that some cont may be NaN
            idx_high = (contamination > 1)
            idx_low = (contamination <= 1)

        # Move out-of-range and NaN contaminations (NaN should only come from
        # Halo targets) onto fixed marker rows above the valid range.
        contamination[idx_high] = 1.1
        contamination[idx_invalid] = 1.2

        # Plot individual contamination points
        low_colors = subset_colors(rgba_color, idx_low)
        ax.scatter(tmags[idx_low],
                   contamination[idx_low],
                   marker='o',
                   facecolors=low_colors,
                   color=low_colors,
                   alpha=0.1,
                   rasterized=True)

        ax.scatter(tmags[idx_high],
                   contamination[idx_high],
                   marker='o',
                   facecolors='None',
                   color=subset_colors(rgba_color, idx_high),
                   alpha=0.9)

        # Plot invalid points:
        ax.scatter(tmags[idx_invalid],
                   contamination[idx_invalid],
                   marker='o',
                   facecolors='None',
                   color='r',
                   alpha=0.9)

        mags = np.linspace(dval.tmag_limits[0], dval.tmag_limits[1], 100)
        ax.plot(mags, cont_vs_mag(mags), 'r-')

        ax.set_xlim(dval.tmag_limits)
        ax.set_ylim([-0.05, 1.3])

        ax.axhline(y=0, ls='--', color='k', zorder=-1)
        ax.axhline(y=1.1, ls=':', color='k', zorder=-1)
        ax.axhline(y=1.2, ls=':', color='r', zorder=-1)
        ax.axhline(y=1, ls=':', color='k', zorder=-1)
        ax.set_xlabel('TESS magnitude')
        ax.set_ylabel('Contamination')

        ax.xaxis.set_major_locator(MultipleLocator(2))
        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(0.2))
        ax.yaxis.set_minor_locator(MultipleLocator(0.1))

        fig.savefig(os.path.join(dval.outfolder, f'contam_c{cadence:04d}'))
        if not dval.show:
            plt.close(fig)

        # Assign validation bits
        dv = np.zeros_like(pri, dtype='int32')
        dv[contamination >= 1] |= DatavalQualityFlags.InvalidContamination
        dv[(contamination > cont_vs_mag(tmags))
           & (contamination < 1)] |= DatavalQualityFlags.ContaminationHigh
        dval.update_dataval(pri, dv)
Exemple #37
0
    def draw(self):
        """Render the dot plot, slope colorbar and frequency histogram.

        Builds three side-by-side panels on ``self.figure``:

        * a scatter of frequency against time (or against dot index in
          compressed mode). Dots are coloured by ``self.slopes`` and sized by
          ``self.scaled_amplitudes``. The panel has optional harmonics, pulse
          markers, frequency markers and a cursor, plus a rectangle selector
          that logs the slope of the selected region;
        * a colorbar for the slope colour scale;
        * a horizontal histogram of ``self.freqs`` weighted by
          ``self.amplitudes``, with each bin coloured by its
          amplitude-weighted mean slope.
        """
        # TODO: recycle the figure with `self.fig.clear()` rather than creating new panel and figure each refresh!

        gs = mpl.gridspec.GridSpec(1, 3, width_ratios=[85, 5, 10], wspace=0.025)

        # --- Main dot scatter plot ---
        self.dot_plot = dot_plot = self.figure.add_subplot(gs[0])

        miny, maxy = self.config['freqminmax']
        plot_kwargs = dict(cmap=self.config['colormap'],
                           vmin=self.SLOPE_MIN, vmax=self.SLOPE_MAX,  # vmin/vmax define where we scale our colormap
                           c=self.slopes, s=self.scaled_amplitudes,   # dot color and size
                           linewidths=0.0,
                           )

        def plot_harmonics(x):
            """Reusable way to plot harmonics from different view types"""
            if self.config['harmonics']['0.5']:
                dot_plot.scatter(x, self.freqs/2, alpha=0.2, **plot_kwargs)
            if self.config['harmonics']['2']:
                dot_plot.scatter(x, self.freqs*2, alpha=0.2, **plot_kwargs)
            if self.config['harmonics']['3']:
                dot_plot.scatter(x, self.freqs*3, alpha=0.2, **plot_kwargs)

        if len(self.freqs) < 2:
            dot_scatter = dot_plot.scatter([], [])  # empty set

        elif not self.config['compressed']:
            # Realtime View
            plot_harmonics(self.times)
            dot_scatter = dot_plot.scatter(self.times, self.freqs, **plot_kwargs)
            dot_plot.set_xlim(self.times[0], self.times[-1])
            dot_plot.set_xlabel('Time (sec)')

        else:
            # Compressed (pseudo-Dot-Per-Pixel) View

            if self.config['pulse_markers']:
                for v in self.zc.get_pulses():
                    dot_plot.axvline(v, linewidth=0.5, color='#808080')

            x = range(len(self.freqs))
            plot_harmonics(x)
            dot_scatter = dot_plot.scatter(x, self.freqs, **plot_kwargs)
            dot_plot.set_xlim(0, len(x))
            dot_plot.set_xlabel('Dot Count')

        try:
            dot_plot.set_yscale(self.config['scale'])  # FIXME: fails with "Data has no positive values" error
        except ValueError:
            log.exception('Failed setting log scale (exception caught)')
            log.error('\ntimes: %s\nfreqs: %s\nslopes: %s', self.times, self.freqs, self.slopes)

        dot_plot.set_title(self.name)
        dot_plot.set_ylabel('Frequency (kHz)')
        dot_plot.set_ylim(miny, maxy)

        # remove the default tick labels, then produce our own instead
        dot_plot.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())
        dot_plot.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
        minytick = miny if miny % 10 == 0 else miny + 10 - miny % 10  # round up to next 10kHz tick
        maxytick = maxy if maxy % 10 == 0 else maxy + 10 - maxy % 10
        ticks = range(minytick, maxytick+1, 10)   # labels every 10kHz
        dot_plot.yaxis.set_ticks(ticks)

        dot_plot.set_axisbelow(True)
        dot_plot.grid(axis='y', which='both', linestyle=':')

        for freqk in self.config['markers']:
            dot_plot.axhline(freqk, color='r', linewidth=1.0, zorder=0.9)

        for freqk in self.config['filter_markers']:
            dot_plot.axhline(freqk, color='b', linestyle='--', linewidth=1.1, zorder=0.95)

        # draw X and Y cursor; this may perform better if we can use Wx rather than MatPlotLib, see `wxcursor_demo.py`
        if self.config['display_cursor']:
            self.cursor1 = Cursor(dot_plot, useblit=True, color='black', linewidth=1)

        # experimental rectangle selection
        def onselect(eclick, erelease):
            """Log the slope of a rectangle selection, in octaves per second.

            ``eclick`` and ``erelease`` are the matplotlib events at press and
            release.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            log.debug(' Select  (%.3f,%.1f) -> (%.3f,%.1f)  button: %d', x1, y1, x2, y2, eclick.button)
            if self.config['compressed']:
                # x is a dot index in this view; map it to a time. The x-axis
                # spans [0, len(freqs)], so clamp the rounded index into range.
                times = self.zc.times
                last = len(times) - 1
                x1 = times[min(max(int(round(x1)), 0), last)]
                x2 = times[min(max(int(round(x2)), 0), last)]
            if x2 == x1:
                # A zero-width selection has no defined slope.
                log.debug('         slope: undefined (zero-width selection)')
                return
            slope = (np.log2(y2) - np.log2(y1)) / (x2 - x1)
            log.debug('         slope: %.1f oct/sec  (%.1f kHz / %.3f sec)', slope, y2 - y1, x2 - x1)

        # 'box' was the default drawtype; the keyword itself was removed in
        # newer Matplotlib, so rely on the default.
        self.selector = RectangleSelector(dot_plot, onselect)

        # --- Colorbar plot ---
        self.cbar_plot = cbar_plot = self.figure.add_subplot(gs[1])
        cbar_plot.set_title('Slope')
        try:
            cbar = self.figure.colorbar(dot_scatter, cax=cbar_plot, ticks=[])
        except TypeError:
            # colorbar() blows up on empty set
            sm = ScalarMappable(cmap=self.config['colormap'])  # TODO: this should probably share colormap code with histogram
            sm.set_array(np.array([self.SLOPE_MIN, self.SLOPE_MAX]))
            cbar = self.figure.colorbar(sm, cax=cbar_plot, ticks=[])
        cbar.ax.set_yticklabels([])

        # --- Hist plot ---

        self.hist_plot = hist_plot = self.figure.add_subplot(gs[2])
        hist_plot.set_title('Freqs')

        bin_min, bin_max = self.config['freqminmax']
        bin_size = 2  # khz  # TODO: make this configurable
        bin_n = int((bin_max - bin_min) / bin_size)
        n, bins, patches = hist_plot.hist(self.freqs, weights=self.amplitudes,
                                          range=self.config['freqminmax'], bins=bin_n,
                                          orientation='horizontal',
                                          edgecolor='black')
        hist_plot.set_yscale(self.config['scale'])
        hist_plot.set_ylim(miny, maxy)
        hist_plot.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())

        # color histogram bins by their amplitude-weighted mean slope
        slope_mapper = ScalarMappable(cmap=self.config['colormap'], norm=Normalize(vmin=self.SLOPE_MIN, vmax=self.SLOPE_MAX))
        for bin_start, bin_end, patch in zip(bins[:-1], bins[1:], patches):
            bin_mask = (bin_start <= self.freqs) & (self.freqs < bin_end)
            bin_slopes = self.slopes[bin_mask]
            slope_weights = self.scaled_amplitudes[bin_mask]
            # Empty bins, and bins whose weights sum to zero (np.average would
            # raise ZeroDivisionError), fall back to a neutral 0.0.
            if bin_slopes.size and slope_weights.sum() != 0:
                avg_slope = np.average(bin_slopes, weights=slope_weights)
            else:
                avg_slope = 0.0
            patch.set_facecolor(slope_mapper.to_rgba(avg_slope))

        hist_plot.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())
        hist_plot.yaxis.set_ticks(ticks)
        hist_plot.yaxis.tick_right()

        hist_plot.xaxis.set_ticks([])

        hist_plot.set_axisbelow(True)
        hist_plot.grid(axis='y', which='both', linestyle=':')

        for freqk in self.config['markers']:
            hist_plot.axhline(freqk, color='r', linewidth=1.0, zorder=0.9)

        for freqk in self.config['filter_markers']:
            hist_plot.axhline(freqk, color='b', linestyle='--', linewidth=1.1, zorder=0.95)

        # draw Y cursor
        if self.config['display_cursor']:
            self.cursor3 = Cursor(hist_plot, useblit=True, color='black', linewidth=1, vertOn=False, horizOn=True)
        data_color[i] = (threshold - means[i] + ci_95[i]) / (2 * ci_95[i])
        if data_color[i] > 1: data_color[i] = 1
        if data_color[i] < 0: data_color[i] = 0 
    my_cmap = plt.cm.get_cmap('RdBu')
    plt.bar(x_ticks, means, width=0.8, yerr=ci_95, capsize=5, color=my_cmap(data_color))
    

###plot bars
updateColors(40000)
plt.xticks(x_ticks, years)
axline = plt.axhline(y=40000, clip_on=False, zorder=1, color='red')
txt = plt.gca().text(x=5.5, y=40000, s=40000, color='red')
plt.title('Confidence Interval Interactivity: \nClick the Chart To Recolor')
#show colorbar
# pyplot.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib removed.
my_cmap = plt.get_cmap('RdBu')
sm = ScalarMappable(cmap=my_cmap)
sm.set_array([])
# ``sm`` is not attached to any Axes, so name the Axes to take space from;
# newer Matplotlib no longer falls back to the current Axes.
cbar = plt.colorbar(sm, ax=plt.gca())
cbar.set_label('Color', rotation=270, labelpad=25)


###interactivity
def moveAxHLine(event):
    """Move the threshold line and its label to the clicked y value, then recolour.

    Clicks outside the axes (``event.ydata`` is None) are ignored.
    """
    if event.ydata is None:
        return
    # set_ydata expects a sequence (one y per line vertex).
    axline.set_ydata([event.ydata, event.ydata])
    # move and update the threshold label
    txt.set_position((5.5, event.ydata))
    txt.set_text(event.ydata)
    updateColors(event.ydata)
    # Request a redraw so the changes show without waiting for another event.
    event.canvas.draw_idle()


plt.gcf().canvas.mpl_connect('button_press_event', moveAxHLine)

Exemple #39
0
        cases_at_measurement_times = np.interp(days_elapsed + delta_t, times,
                                               cumulative_incidence)

        accept = (
            (np.asarray(min_number_cases) < cases_at_measurement_times) &
            (cases_at_measurement_times < np.asarray(max_number_cases))).all()
        cax = ax.semilogy(times,
                          cumulative_incidence,
                          '.-',
                          color=cmap(R0_grid[j]),
                          alpha=1.0 if accept else 0.25,
                          zorder=1 if not accept else 5)

# Observed data points (black squares with asymmetric error bars).
plot_kwargs = dict(fmt='s', color='k', zorder=10, ecolor='k')
ax.errorbar(48, 2890, xerr=14, yerr=[[2700], [2700]], **plot_kwargs)
ax.errorbar(52, 5000, xerr=14, yerr=[[4000], [4700]], **plot_kwargs)

# Colorbar for R0 over the range of the grid. ``ax`` is passed explicitly
# because the standalone ScalarMappable is attached to no Axes. The label is
# a raw string because '\m' is an invalid escape in a normal literal.
# NOTE(review): the trajectories use ``cmap`` defined elsewhere, while this
# bar uses viridis — confirm they match.
norm = Normalize(vmin=R0_grid.min(), vmax=R0_grid.max())
cbar = plt.colorbar(mappable=ScalarMappable(norm=norm, cmap=plt.cm.viridis),
                    ax=ax,
                    label=r'$\mathcal{R}_0$')

ax.set_xticks(np.arange(0, 60, 1), minor=True)

for sp in ['right', 'top']:
    ax.spines[sp].set_visible(False)

ax.set_xlabel('Time [days]')
ax.set_ylabel('Cumulative Incidence')
fig.tight_layout()
fig.savefig('plots/trajectories.pdf', bbox_inches='tight')
plt.show()
Exemple #40
0
def add_collage_colorbar(figure,
                         ax,
                         smfs,
                         vmin,
                         vmax,
                         cbar_vmin=None,
                         cbar_vmax=None,
                         midpoint=None,
                         multicollage=False,
                         colorbar_params=None,
                         **kwargs):
    """Add one shared colorbar to a collage of surface plots.

    Parameters
    ----------
    figure : matplotlib Figure
        Figure receiving the colorbar axes in the non-multicollage layout.
    ax : Axes or grid of Axes
        With ``multicollage`` the value is passed to ``plt.colorbar(ax=...)``.
        Otherwise it is indexed as ``ax[row][col]`` and must be at least a
        2x2 grid.
    smfs : sequence of arrays
        Data values; they are concatenated and set on the proxy mappable.
    vmin, vmax : float
        Colour-scale limits; tick positions are 5 evenly spaced values.
    cbar_vmin, cbar_vmax : float, optional
        Forwarded to ``_crop_colorbar``.
    midpoint : float, optional
        If given, use ``MidpointNormalize`` centred on this value.
    multicollage : bool
        Choose between a vertical bar next to ``ax`` and a horizontal bar
        laid out across the first two columns of the grid.
    colorbar_params : dict, optional
        Overrides for ``fraction`` (0.5), ``shrink`` (0.5), ``aspect`` (20),
        ``pad`` (0.5), ``anchor`` ('C') and ``format`` ('%.2g').
    **kwargs
        ``cmap`` (name or Colormap; defaults to rcParamsDefault
        ``image.cmap``) and ``threshold``. Colormap entries whose values lie
        within ``(-threshold, threshold)`` are painted grey.

    Returns
    -------
    (figure, ax)
        The inputs, after the colorbar has been added.
    """
    # Avoid a shared mutable default argument.
    if colorbar_params is None:
        colorbar_params = {}
    fraction = colorbar_params.get('fraction', .5)
    shrink = colorbar_params.get('shrink', .5)
    aspect = colorbar_params.get('aspect', 20)
    pad = colorbar_params.get('pad', .5)
    anchor = colorbar_params.get('anchor', 'C')
    c_format = colorbar_params.get('format', '%.2g')

    cmap = kwargs.pop('cmap', None)
    if cmap is None:
        cmap = get_cmap(plt.rcParamsDefault['image.cmap'])
    elif isinstance(cmap, str):
        cmap = get_cmap(cmap)

    threshold = kwargs.get('threshold')

    # Color bar
    our_cmap = cmap

    if midpoint is None:
        norm = Normalize(vmin=vmin, vmax=vmax)
    else:
        norm = MidpointNormalize(midpoint=midpoint, vmin=vmin, vmax=vmax)

    nb_ticks = 5
    ticks = np.linspace(vmin, vmax, nb_ticks)

    if threshold is not None:
        cmaplist = [our_cmap(i) for i in range(our_cmap.N)]

        # set colors to grey for absolute values < threshold
        istart = int(norm(-threshold, clip=True) * (our_cmap.N - 1))
        istop = int(norm(threshold, clip=True) * (our_cmap.N - 1))
        for i in range(istart, istop):
            cmaplist[i] = (0.5, 0.5, 0.5, 1.)
        our_cmap = LinearSegmentedColormap.from_list('Custom cmap', cmaplist,
                                                     our_cmap.N)

    # we need to create a proxy mappable
    proxy_mappable = ScalarMappable(cmap=our_cmap, norm=norm)
    proxy_mappable.set_array(np.concatenate(smfs))

    if multicollage:

        cbar = plt.colorbar(proxy_mappable,
                            ax=ax,
                            ticks=ticks,
                            spacing='proportional',
                            format=c_format,
                            orientation='vertical',
                            anchor=anchor,
                            fraction=fraction,
                            shrink=shrink,
                            aspect=aspect,
                            pad=pad)

    else:

        # Horizontal bar spanning from the centre of ax[0][0] to the centre
        # of ax[0][1], placed around the top edge of ax[1][0].
        left = (ax[0][0].get_position().x0 + ax[0][0].get_position().x1) / 2
        right = (ax[0][1].get_position().x0 + ax[0][1].get_position().x1) / 2
        bot = ax[1][0].get_position().y1
        width = right - left

        # [left, bottom, width, height]
        cbaxes = figure.add_axes([left, bot - (.05 / 3), width, .05])

        cbar = plt.colorbar(proxy_mappable,
                            cax=cbaxes,
                            ticks=ticks,
                            spacing='proportional',
                            format=c_format,
                            orientation='horizontal',
                            shrink=1,
                            anchor='C')

    _crop_colorbar(cbar, cbar_vmin, cbar_vmax)

    return figure, ax
Exemple #41
0
def animate_contour(xdata=None,
                    ydata=None,
                    zdata=None,
                    times=None,
                    timescale='s',
                    title='',
                    xlabel='',
                    ylabel='',
                    cmap=cm_hot_desaturated,
                    zlims=None,
                    levels=100,
                    save=False,
                    filename=None):
    """Animate a sequence of 2-D fields as filled contour plots.

    Parameters
    ----------
    xdata, ydata : sequence, optional
        Grid coordinates of the columns / rows. They default to the column /
        row indices of ``zdata[0]``.
    zdata : sequence of 2-D arrays
        One field per frame.
    times : sequence of float, optional
        Time of each frame, shown in the title as ``{time:4.3f}{timescale}``.
        Without it, the title shows ``frame/total``.
    timescale : str
        Unit suffix for ``times``.
    title, xlabel, ylabel : str
        Axis decorations; the frame label is appended to ``title``.
    cmap : colormap
        Colormap for ``contourf``.
    zlims : (float, float), optional
        Fixed colour limits. When given, every frame uses the same norm and
        one static colorbar. Otherwise frames are auto-scaled and the colorbar
        is redrawn each frame.
    levels : int or sequence
        Passed to ``contourf``.
    save : bool
        Save the animation with the ffmpeg writer to ``filename`` instead of
        showing it.
    filename : str, optional
        Output path; required when ``save`` is true.

    Raises
    ------
    ValueError
        If ``save`` is true but ``filename`` is None.
    """
    # Fail before any (possibly long) rendering work.
    if save and filename is None:
        raise ValueError('filename is required when save=True')

    nt = len(zdata)

    if xdata is None:
        xdata = list(range(zdata[0].shape[1]))
    if ydata is None:
        ydata = list(range(zdata[0].shape[0]))

    vmin = zlims[0] if zlims is not None else np.min(zdata)
    vmax = zlims[1] if zlims is not None else np.max(zdata)

    norm = Normalize(vmin=vmin, vmax=vmax)

    def frame_title(i):
        # Title for frame i: the frame time if known, else "i/nt".
        if times is None:
            return title + f'\n{i}/{nt}'
        return title + f'\n{times[i]:4.3f}{timescale}'

    def remove_contours(contour_set):
        # ContourSet.remove() exists in recent Matplotlib; older releases
        # only support removing its individual collections.
        try:
            contour_set.remove()
        except AttributeError:
            for collection in contour_set.collections:
                collection.remove()

    fig, ax = plt.subplots(1, 1)

    fig_title = ax.set_title(frame_title(0))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    cax = make_axes_locatable(ax).append_axes('right', '5%', '5%')

    # Handle on the contour set currently drawn, so each frame replaces it
    # instead of stacking yet another set on top of all previous frames.
    current = ax.contourf(xdata,
                          ydata,
                          zdata[0],
                          cmap=cmap,
                          levels=levels,
                          norm=norm)

    if zlims is not None:
        fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cax)

    def animate(i):
        nonlocal current
        remove_contours(current)
        if zlims is None:
            im = ax.contourf(xdata, ydata, zdata[i], cmap=cmap, levels=levels)
            cax.clear()
            fig.colorbar(im, cax=cax)
        else:
            im = ax.contourf(xdata,
                             ydata,
                             zdata[i],
                             cmap=cmap,
                             levels=levels,
                             norm=norm)
        current = im

        fig_title.set_text(frame_title(i))
        return im,

    anim = FuncAnimation(fig,
                         animate,
                         frames=nt,
                         interval=100,
                         blit=False)

    if save:
        anim.save(filename, writer='ffmpeg')
        print(f'Animation saved as {filename}.')
    else:
        plt.show()
Exemple #42
0
               yerr=(error_low, error_high),
               error_kw={
                   'capsize': 10,
                   'elinewidth': 2,
                   'alpha': 1
               }))

###Create the textarea widget
txt = wdg.Textarea(value='',
                   placeholder='',
                   description='Y Value:',
                   disabled=False)

### Format the color bar. A ScalarMappable is needed to draw a colorbar
### without a plotted image. pyplot.get_cmap replaces plt.cm.get_cmap,
### which newer Matplotlib removed.
my_cmap = plt.get_cmap('coolwarm')
sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0, 1))
sm.set_array([])
# ``sm`` is not attached to any Axes, so name the Axes to take space from;
# newer Matplotlib no longer falls back to the current Axes.
cbar = plt.colorbar(sm, ax=plt.gca())
cbar.set_label('Probability', rotation=270, labelpad=25)

# Initial y value of the horizontal line drawn by ClickChart.
ydataselect = 40000


class ClickChart:
    """Draw a horizontal reference line on an Axes and listen for clicks.

    The line starts at the module-level ``ydataselect`` value.

    NOTE(review): the constructor registers ``self.onclick`` as the
    ``'button_press_event'`` handler, but no ``onclick`` method is defined
    in the code shown here. Without one, construction raises AttributeError;
    confirm the method exists.
    """

    def __init__(self, ax):
        figure = ax.figure
        self.fig = figure
        self.ax = ax
        self.horiz_line = ax.axhline(y=ydataselect, color='black',
                                     linewidth=2)
        figure.canvas.mpl_connect('button_press_event', self.onclick)

Exemple #43
0
    m.fillcontinents(color='#f2f2f2', lake_color='#46bcec')
    m.drawcoastlines()
    #contains all state position coordinates
    m.readshapefile("INDIA/IND_adm1", "INDIA")

    colors = {}
    statenames = []
    patches = []
    #Colormap
    cmap = plt.cm.Reds
    #Colorbar Range
    vmin = min(color_dict.values())
    vmax = max(color_dict.values())
    norm = Normalize(vmin=vmin, vmax=vmax)
    # color mapper to covert values to colors
    mapper = ScalarMappable(norm=norm, cmap=cmap)

    for shapedict in m.INDIA_info:
        statename = shapedict['NAME_1']
        #To incorporate difference between Map State Name and Data Loc Name
        if statename == "Telangana":
            statename = "Telengana"
        if statename in color_dict:
            pop = color_dict[statename]
            colors[statename] = mapper.to_rgba(pop)
            statenames.append(statename)

    for nshape, seg in enumerate(m.INDIA):
        color = rgb2hex(colors[statenames[nshape]])
        poly = Polygon(seg, facecolor=color, edgecolor=color)
        ax.add_patch(poly)
Exemple #44
0
def plot_surf(surf_mesh,
              surf_map=None,
              bg_map=None,
              hemi='left',
              view='lateral',
              cmap=None,
              colorbar=False,
              avg_method='mean',
              threshold=None,
              alpha='auto',
              bg_on_data=False,
              darkness=1,
              vmin=None,
              vmax=None,
              cbar_vmin=None,
              cbar_vmax=None,
              title=None,
              output_file=None,
              axes=None,
              figure=None,
              **kwargs):
    """ Plotting of surfaces with optional background and data

    .. versionadded:: 0.3

    Parameters
    ----------
    surf_mesh: str or list of two numpy.ndarray
        Surface mesh geometry, can be a file (valid formats are
        .gii or Freesurfer specific files such as .orig, .pial,
        .sphere, .white, .inflated) or
        a list of two Numpy arrays, the first containing the x-y-z coordinates
        of the mesh vertices, the second containing the indices
        (into coords) of the mesh faces.

    surf_map: str or numpy.ndarray, optional.
        Data to be displayed on the surface mesh. Can be a file (valid formats
        are .gii, .mgz, .nii, .nii.gz, or Freesurfer specific files such as
        .thickness, .curv, .sulc, .annot, .label) or
        a Numpy array with a value for each vertex of the surf_mesh.

    bg_map: Surface data object (to be defined), optional,
        Background image to be plotted on the mesh underneath the
        surf_data in greyscale, most likely a sulcal depth map for
        realistic shading. It is shown even when no surf_map is given.

    hemi : {'left', 'right'}, default is 'left'
        Hemisphere to display.

    view: {'lateral', 'medial', 'dorsal', 'ventral', 'anterior', 'posterior'},
        default is 'lateral'
        View of the surface that is rendered.

    cmap: matplotlib colormap, str or colormap object, default is None
        To use for plotting of the stat_map. Either a string
        which is a name of a matplotlib colormap, or a matplotlib
        colormap object. If None, matplotlib default will be chosen

    colorbar : bool, optional, default is False
        If True, a colorbar of surf_map is displayed.

    avg_method: {'mean', 'median'}, default is 'mean'
        How to average vertex values to derive the face value, mean results
        in smooth, median in sharp boundaries. Any other value raises a
        ValueError when surf_map is given.

    threshold : a number or None, default is None.
        If None is given, the image is not thresholded.
        If a number is given, it is used to threshold the image, values
        below the threshold (in absolute value) are plotted as transparent.

    alpha: float, alpha level of the mesh (not surf_data), default 'auto'
        If 'auto' is chosen, alpha will default to .5 when no bg_map
        is passed and to 1 if a bg_map is passed.

    bg_on_data: bool, default is False
        If True, and a bg_map is specified, the surf_data data is multiplied
        by the background image, so that e.g. sulcal depth is visible beneath
        the surf_data.
        NOTE: that this non-uniformly changes the surf_data values according
        to e.g the sulcal depth.

    darkness: float, between 0 and 1, default is 1
        Specifying the darkness of the background image.
        1 indicates that the original values of the background are used.
        .5 indicates the background values are reduced by half before being
        applied.

    vmin, vmax: lower / upper bound to plot surf_data values
        If None , the values will be set to min/max of the data

    cbar_vmin, cbar_vmax: float or None, optional
        Passed to ``_crop_colorbar`` to limit the range of the displayed
        colorbar. Only used when colorbar is True and surf_map is given.

    title : str, optional
        Figure title.

    output_file: str, or None, optional
        The name of an image file to export plot to. Valid extensions
        are .png, .pdf, .svg. If output_file is not None, the plot
        is saved to a file, and the display is closed.

    axes: instance of matplotlib axes, None, optional
        The axes instance to plot to. The projection must be '3d' (e.g.,
        `figure, axes = plt.subplots(subplot_kw={'projection': '3d'})`,
        where axes should be passed.).
        If None, a new axes is created.

    figure: instance of matplotlib figure, None, optional
        The figure instance to plot to. If None, a new figure is created.

    kwargs:
        Accepted for backward compatibility; currently unused.

    Returns
    -------
    figure: matplotlib figure or None
        The figure containing the plot, or None when output_file is given
        (the figure is then saved and closed).

    Raises
    ------
    ValueError
        If hemi, view or avg_method is invalid, or if bg_map / surf_map do
        not have one value per mesh vertex, or surf_map is not 1-D.

    See Also
    --------
    nilearn.datasets.fetch_surf_fsaverage : For surface data object to be
        used as background map for this plotting function.

    nilearn.plotting.plot_surf_roi : For plotting statistical maps on brain
        surfaces.

    nilearn.plotting.plot_surf_stat_map : for plotting statistical maps on
        brain surfaces.
    """

    # load mesh and derive axes limits
    mesh = load_surf_mesh(surf_mesh)
    coords, faces = mesh[0], mesh[1]
    limits = [coords.min(), coords.max()]

    # set view: (elev, azim) camera angles. Only lateral and medial depend on
    # the hemisphere, since they look at opposite sides of it.
    if hemi not in ('left', 'right'):
        raise ValueError('hemi must be one of right or left')
    if hemi == 'right':
        lateral, medial = (0, 0), (0, 180)
    else:
        lateral, medial = (0, 180), (0, 0)
    view_angles = {
        'lateral': lateral,
        'medial': medial,
        'dorsal': (90, 0),
        'ventral': (270, 0),
        'anterior': (0, 90),
        'posterior': (0, 270),
    }
    if view not in view_angles:
        raise ValueError('view must be one of lateral, medial, '
                         'dorsal, ventral, anterior, or posterior')
    elev, azim = view_angles[view]

    # set alpha if in auto mode
    if alpha == 'auto':
        alpha = .5 if bg_map is None else 1

    # if no cmap is given, set to matplotlib default; translate names to
    # colormap objects (plt.cm.get_cmap was removed in Matplotlib 3.9)
    if cmap is None:
        cmap = plt.get_cmap(plt.rcParamsDefault['image.cmap'])
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # initiate figure and 3d axes
    if axes is None:
        if figure is None:
            figure = plt.figure()
        # add_axes registers the axes with the figure; a bare Axes3D(figure)
        # is no longer added automatically by recent Matplotlib versions.
        axes = figure.add_axes([0, 0, 1, 1], projection='3d',
                               xlim=limits, ylim=limits)
    else:
        if figure is None:
            figure = axes.get_figure()
        axes.set_xlim(*limits)
        axes.set_ylim(*limits)
    axes.view_init(elev=elev, azim=azim)
    axes.set_axis_off()

    # plot mesh without data
    p3dcollec = axes.plot_trisurf(coords[:, 0],
                                  coords[:, 1],
                                  coords[:, 2],
                                  triangles=faces,
                                  linewidth=0.,
                                  antialiased=False,
                                  color='white')

    # reduce viewing distance to remove space around mesh
    axes.dist = 8

    # set_facecolors function of Poly3DCollection is used as passing the
    # facecolors argument to plot_trisurf does not seem to work
    if bg_map is None:
        bg_data = np.ones(coords.shape[0]) * 0.5
    else:
        bg_data = load_surf_data(bg_map)
        if bg_data.shape[0] != coords.shape[0]:
            raise ValueError('The bg_map does not have the same number '
                             'of vertices as the mesh.')

    # one background value per face, rescaled to [0, 1] when not constant
    bg_faces = np.mean(bg_data[faces], axis=1)
    if bg_faces.min() != bg_faces.max():
        bg_faces = bg_faces - bg_faces.min()
        bg_faces = bg_faces / bg_faces.max()
    # control background darkness
    bg_faces *= darkness
    face_colors = plt.cm.gray_r(bg_faces)

    # modify alpha values of background
    face_colors[:, 3] = alpha * face_colors[:, 3]
    # should it be possible to modify alpha of surf data as well?

    if surf_map is not None:
        surf_map_data = load_surf_data(surf_map)
        if surf_map_data.ndim != 1:
            raise ValueError('surf_map can only have one dimension but has '
                             '%i dimensions' % surf_map_data.ndim)
        if surf_map_data.shape[0] != coords.shape[0]:
            raise ValueError('The surf_map does not have the same number '
                             'of vertices as the mesh.')

        # create face values from vertex values by selected avg methods
        if avg_method == 'mean':
            surf_map_faces = np.mean(surf_map_data[faces], axis=1)
        elif avg_method == 'median':
            surf_map_faces = np.median(surf_map_data[faces], axis=1)
        else:
            raise ValueError("avg_method should be either 'mean' or "
                             "'median', got %r" % (avg_method,))

        # if no vmin/vmax are passed figure them out from data
        if vmin is None:
            vmin = np.nanmin(surf_map_faces)
        if vmax is None:
            vmax = np.nanmax(surf_map_faces)

        # threshold if indicated
        if threshold is None:
            kept_indices = np.arange(surf_map_faces.shape[0])
        else:
            kept_indices = np.where(np.abs(surf_map_faces) >= threshold)[0]

        # rescale to [0, 1] for the colormap
        surf_map_faces = surf_map_faces - vmin
        surf_map_faces = surf_map_faces / (vmax - vmin)

        # multiply data with background if indicated
        if bg_on_data:
            face_colors[kept_indices] = cmap(surf_map_faces[kept_indices])\
                * face_colors[kept_indices]
        else:
            face_colors[kept_indices] = cmap(surf_map_faces[kept_indices])

        if colorbar:
            # cmap is already a Colormap object at this point
            our_cmap = cmap
            norm = Normalize(vmin=vmin, vmax=vmax)

            nb_ticks = 5
            ticks = np.linspace(vmin, vmax, nb_ticks)
            bounds = np.linspace(vmin, vmax, our_cmap.N)

            if threshold is not None:
                cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
                # set colors to grey for absolute values < threshold
                istart = int(norm(-threshold, clip=True) * (our_cmap.N - 1))
                istop = int(norm(threshold, clip=True) * (our_cmap.N - 1))
                for i in range(istart, istop):
                    cmaplist[i] = (0.5, 0.5, 0.5, 1.)
                our_cmap = LinearSegmentedColormap.from_list(
                    'Custom cmap', cmaplist, our_cmap.N)

            # we need to create a proxy mappable
            proxy_mappable = ScalarMappable(cmap=our_cmap, norm=norm)
            proxy_mappable.set_array(surf_map_faces)
            cax, kw = make_axes(axes,
                                location='right',
                                fraction=.1,
                                shrink=.6,
                                pad=.0)
            cbar = figure.colorbar(proxy_mappable,
                                   cax=cax,
                                   ticks=ticks,
                                   boundaries=bounds,
                                   spacing='proportional',
                                   format='%.2g',
                                   orientation='vertical')
            _crop_colorbar(cbar, cbar_vmin, cbar_vmax)

    # Apply the colors in every case, so that a bg_map (and alpha/darkness)
    # is also rendered when no surf_map is given.
    p3dcollec.set_facecolors(face_colors)

    if title is not None:
        axes.set_title(title, position=(.5, .95))

    # save figure if output file is given
    if output_file is not None:
        figure.savefig(output_file)
        plt.close(figure)
    else:
        return figure
Exemple #45
0
def visualize(samples, labels):
    """Interactively browse radar samples with the keyboard.

    Note:
        First init plots then update per user keypress: ``n`` shows the next
        sample, ``b`` the previous one and ``escape`` closes the figure.
        Samples are in the form [(xz, yz, xy), ...] in range [0, common.RADAR_MAX].

    Args:
        samples (list of tuples of np.array): Data to visualize.
        labels (list of strings): Sample labels.
    """

    def colorize(data):
        # A fresh ScalarMappable has no fixed vmin/vmax, so it scales ``data``
        # to its own min/max: every sample uses the full colormap range.
        return ScalarMappable(cmap='coolwarm').to_rgba(data)

    def make_title(i):
        return f'Target Return Signal. Label "{labels[i]}", Sample {i}.'

    fig = plt.figure()
    gs = fig.add_gridspec(2, 2)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, :])

    # Get position maps.
    pmap_xz, pmap_yz = gen_pos_map()

    idx = 0

    # Initial sample.
    xz, yz, xy = samples[idx]

    title = fig.suptitle(make_title(idx))

    # Colors below are passed as explicit RGBA values, so no ``cmap`` is given
    # to scatter/imshow (Matplotlib ignores it for RGBA input).

    # Setup x-z plane plot.
    # Radar target return signal strength in x-z plane.
    init_axis(ax1, 'X-Z Plane', 'X (cm)', 'Z (cm)')
    pts_xz = ax1.scatter(pmap_xz[0],
                         pmap_xz[1],
                         s=pmap_xz[2],
                         c=colorize(xz.T.flatten()),
                         zorder=1)

    # Setup y-z plane plot.
    # Radar target return signal strength in y-z plane.
    init_axis(ax2, 'Y-Z Plane', 'Y (cm)', 'Z (cm)')
    pts_yz = ax2.scatter(pmap_yz[0],
                         pmap_yz[1],
                         s=pmap_yz[2],
                         c=colorize(yz.T.flatten()),
                         zorder=1)

    # Setup x-y plane plot.
    # Radar target return signal strength in x-y plane.
    init_axis(ax3, 'X-Y Plane', 'X (cm)', 'Y (cm)')

    # Calculate axis range to set axis limits and plot extent.
    # ``np.int`` was removed in NumPy 1.24; ``int()`` truncates the same way.
    xmax = int(np.amax(pmap_xz[0]))
    xmin = int(np.amin(pmap_xz[0]))
    ymax = int(np.amax(pmap_yz[0]))
    ymin = int(np.amin(pmap_yz[0]))

    ax3.set_xlim(xmax, xmin)
    ax3.set_ylim(ymax, ymin)
    # Rotate xy image if radar horizontal since x and y axis are rotated 90 deg CCW.
    if RADAR_HORIZONTAL:
        xy = np.rot90(xy)
    pts_xy = ax3.imshow(colorize(xy),
                        extent=[xmin, xmax, ymin, ymax],
                        zorder=1)

    def update(event):
        """Update plot per keypress."""
        nonlocal idx

        if event.key == 'n':  # next sample, stop at the last one
            idx = min(idx + 1, len(samples) - 1)
        elif event.key == 'b':  # prev sample, stop at the first one
            idx = max(idx - 1, 0)
        elif event.key == 'escape':  # all done
            plt.close(fig)
            return
        else:  # any other key: nothing changed, nothing to redraw
            return

        # Get next sample.
        xz, yz, xy = samples[idx]

        # Update title.
        title.set_text(make_title(idx))

        # Update image colors according to return signal strength on plots.
        pts_xz.set_color(colorize(xz.T.flatten()))
        pts_yz.set_color(colorize(yz.T.flatten()))

        if RADAR_HORIZONTAL:
            xy = np.rot90(xy)
        pts_xy.set_data(colorize(xy))

        # Actual update of plot.
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('key_press_event', update)

    plt.show()
Exemple #46
0
    bin_width = ticks[1] - ticks[0]
    bin_centers = np.array(ticks[:-1]) + bin_width / 2
    hist_xrange_kwargs = dict(start=min(ticks), end=max(ticks))
else:
    levels = MaxNLocator(nbins=config.NBINS).tick_values(
        config.MIN_VAL, config.MAX_VAL)
    ticks = levels[::3]
    cbkwargs = {}
    cmkwargs = dict(low=config.MIN_VAL, high=config.MAX_VAL)
    bin_width = levels[1] - levels[0]
    bin_centers = levels[:-1] + bin_width / 2
    hist_xrange_kwargs = dict(start=config.MIN_VAL, end=config.MAX_VAL)

# Discrete color palette with one color per bin delimited by ``levels``.
cmap = get_cmap(config.CMAP)
# BoundaryNorm maps each value to the index of the ``levels`` bin containing
# it; clip=True sends out-of-range values to the first/last color.
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)
sm = ScalarMappable(norm=norm, cmap=cmap)
# Hex color of each bin. ``levels`` are bin edges (see ``bin_centers`` above),
# so the color obtained for the last edge is dropped.
color_pal = [
    RGB(*val).to_hex()
    for val in sm.to_rgba(levels, bytes=True, norm=True)[:-1]
]

# Binned palette: the bin colors plus white as an extra final entry
# (presumably for values beyond the last bin -- TODO confirm).
bin_pal = color_pal.copy()
bin_pal.append('#ffffff')
bin_mapper = BinnedColorMapper(bin_pal, alpha=config.ALPHA)
# Continuous mapper over the same palette; low/high come from ``cmkwargs``
# prepared in the branch above.
color_mapper = LinearColorMapper(color_pal, **cmkwargs)
# Color bar with ticks fixed at ``ticks``.
ticker = FixedTicker(ticks=ticks)
cb = ColorBar(color_mapper=color_mapper,
              location=(0, 0),
              scale_alpha=config.ALPHA,
              ticker=ticker,
              **cbkwargs)
Exemple #47
0
        rasterized=True,
    )
    for sp in ax.spines.values():
        sp.set_visible(False)

    plt.xticks([])
    plt.yticks([])
    plt.title(top_gene)

# colorbars

from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

# Upper color bar ("Module Score", viridis) in a narrow axes at the right
# edge of the figure; the rect is [left, bottom, width, height] in figure
# fractions.
# NOTE(review): ``Normalize()`` has no vmin/vmax and the proxy mappable holds
# no data, so the color bar presumably falls back to a 0-1 range and the -1
# tick is not drawn. Confirm the limits used by the scatter plots above and
# pass them here (e.g. ``Normalize(vmin=-1, vmax=1)``).
cbax = fig.add_axes([.93, .6, .01, .2])
plt.colorbar(ScalarMappable(norm=Normalize(), cmap="viridis"),
             cax=cbax,
             ticks=[-1, 0, 1],
             label="Module Score")

# Lower color bar for the gene expression colormap (same NOTE applies).
cbax = fig.add_axes([.93, .2, .01, .2])
plt.colorbar(
    ScalarMappable(norm=Normalize(), cmap=gene_cmap),
    cax=cbax,
    ticks=[-1, 0, 1],
    label="Normalized Log\nExpression",
)

plt.show()
# plt.savefig("Module_Scores.svg", dpi=300)
Exemple #48
0
def show_graph_with_labels(corr_mat,
                           th,
                           trps,
                           halftimes,
                           prop=None,
                           path=None,
                           base_name=None):
    """Plot the graph of tripeptides whose pairwise correlation is >= ``th``.

    Nodes are colored by ``arcsinh`` of their reaction constant
    (``halftimes[node]['lambda']``) and edges by their correlation; a color
    bar is drawn for each.

    Args:
        corr_mat: Correlation matrix (presumably square) whose row/column
            ``i`` corresponds to ``trps[i]``.
        th: Correlation threshold; pairs with ``corr_mat >= th`` get an edge.
        trps: Node labels indexed like the rows/columns of ``corr_mat``.
        halftimes: Mapping node -> dict with a ``'lambda'`` entry and, when
            ``prop`` is given, ``'X'`` / ``'Y'`` entries indexed by ``prop``.
        prop: If None, nodes are placed with a Kamada-Kawai layout weighted
            by correlation; otherwise ``halftimes[node]['X'][prop]`` and
            ``halftimes[node]['Y'][prop]`` are used as coordinates.
        path, base_name: When both are given, the figure is saved to
            ``path + base_name + '_' + prop + '_graph.png'`` (or
            ``path + base_name + '_graph.png'`` when ``prop`` is None)
            instead of being shown.

    Returns:
        networkx.Graph: The thresholded correlation graph.

    Raises:
        ValueError: If no pair reaches the threshold (nothing to plot).
    """
    # Every (row, col) pair reaching the threshold becomes an edge carrying
    # its correlation; NaN correlations compare False and are skipped.
    rows, cols = np.where(corr_mat >= th)
    edges = [(trps[r], trps[c], {'corr': corr_mat[r, c]})
             for r, c in zip(rows.tolist(), cols.tolist())]

    gr = nx.Graph()
    gr.add_edges_from(edges)
    if gr.number_of_edges() == 0:
        raise ValueError('No correlation reaches the threshold %r; '
                         'nothing to plot.' % (th,))

    # Edge attributes are returned in gr.edges() order, matching the order
    # in which draw_networkx applies ``edge_color``.
    corrs = np.array(list(nx.get_edge_attributes(gr, 'corr').values()))

    if prop is None:
        pos = nx.kamada_kawai_layout(gr, weight='corr')
    else:
        pos = {node: (halftimes[node]['X'][prop], halftimes[node]['Y'][prop])
               for node in gr.nodes()}

    # arcsinh is log-like for large values, compressing the color range of
    # the reaction constants.
    hts = np.arcsinh([halftimes[node]['lambda'] for node in gr.nodes()])
    node_cmap = plt.cm.coolwarm_r
    node_vmin, node_vmax = np.min(hts), np.max(hts)
    nodesm = ScalarMappable(norm=plt.Normalize(node_vmin, node_vmax),
                            cmap=node_cmap)
    nodesm.set_array([])

    # Correlations >= 1.0 (presumably self-pairs from the diagonal) are left
    # out of the upper color limit so they do not stretch it; fall back to
    # the overall maximum when nothing is below 1.0.
    low_lim = np.min(corrs)
    below_one = corrs[corrs < 1.0]
    max_lim = np.max(below_one) if below_one.size else np.max(corrs)
    edge_cmap = plt.cm.Greys
    edgesm = ScalarMappable(norm=plt.Normalize(low_lim, max_lim),
                            cmap=edge_cmap)
    edgesm.set_array([])

    fig = plt.figure(figsize=[10, 6])
    nx.draw_networkx(gr,
                     node_size=100,
                     with_labels=True,
                     pos=pos,
                     alpha=0.9,
                     node_color=hts,
                     cmap=node_cmap,
                     vmin=node_vmin,
                     vmax=node_vmax,
                     edge_color=corrs,
                     edge_cmap=edge_cmap,
                     edge_vmin=low_lim,
                     edge_vmax=max_lim,
                     width=1.2,
                     linewidths=2,
                     font_size=5,
                     font_color='black',
                     font_weight='semibold')
    # The proxy mappables are not attached to any axes, so name the axes the
    # color bars take their space from.
    ax = fig.gca()
    corr_cbar = plt.colorbar(edgesm, ax=ax, orientation="vertical")
    corr_cbar.ax.set_title('Correlation')

    ht_cbar = plt.colorbar(nodesm, ax=ax, orientation='vertical')
    ht_cbar.ax.set_title('Reaction constant')

    plt.title('Correlation between tripeptides')
    if prop is not None:
        plt.xlabel('X ' + prop)
        plt.ylabel('Z ' + prop)

    if path is not None and base_name is not None:
        suffix = '_graph.png' if prop is None else '_' + prop + '_graph.png'
        out_file = path + base_name + suffix
        plt.savefig(out_file, dpi=400)
    else:
        plt.show()
    return gr
Exemple #49
0
              zorder=10,
              transform=ccrs.PlateCarree())

# Plot buffer that shows where we got cross section stuff from: a dashed,
# unfilled black outline, presumably of the region within ``epi_area`` of the
# (qlat, qlon) line -- TODO confirm against plot_line_buffer.
_ = plot_line_buffer(qlat,
                     qlon,
                     delta=epi_area,
                     zorder=5,
                     linestyle='--',
                     linewidth=1.0,
                     alpha=1.0,
                     facecolor='none',
                     edgecolor='k')

# Create colorbar from an artificial ScalarMappable (magma + ``illumnorm``)
# rather than from the plotted artist, whose alpha is problematic.
c = colorbar(ScalarMappable(cmap=get_cmap('magma'), norm=illumnorm),
             orientation='vertical',
             aspect=40)

# Set colormap alpha manually (presumably to match the transparency of the
# map drawn above -- TODO confirm).
c.solids.set(alpha=0.5)

# ############### Plot section ###################

# Plot section: colormap name for the section plot.
rfcmap = 'seismic'

# New figure with a light grey axes background for the section.
figure()
ax = axes(facecolor=(0.8, 0.8, 0.8))

# Plot section
def main(argv):
    """Animate a scatter of predicted vs. actual loss through time.

    The metrics file is loaded with ``np.loadtxt``; column 0 is plotted on
    the x axis, column 1 on the y axis and column 2 is mapped to colors.
    Each frame shows ``--n_points`` consecutive rows starting at
    ``frame * --step`` together with a linear fit of column 1 on column 0.

    Args:
        argv: Command-line arguments (without the program name).
    """
    parser = argparse.ArgumentParser(
        description="Plot the loss evolving through time")

    parser.add_argument("metrics",
                        type=file_or_stdin,
                        help="The file containing the loss")

    parser.add_argument("--to_file", help="Save the animation to a video file")
    parser.add_argument("--step",
                        type=int,
                        default=1000,
                        help="Change that many datapoints in between frames")
    parser.add_argument("--n_points",
                        type=int,
                        default=10000,
                        help="That many points in each frame")
    parser.add_argument("--frames",
                        type=lambda x: slice(*map(maybe_int, x.split(":"))),
                        default=":",
                        help="Choose only those frames")
    # A list (not a lazy ``map``) so that ``lims[1]`` below works.
    parser.add_argument("--lim",
                        type=lambda x: [float(v) for v in x.split(",")],
                        help="Define the limits of the axes")
    parser.add_argument("--no_colorbar",
                        action="store_false",
                        dest="colorbar",
                        help="Do not display a colorbar")

    args = parser.parse_args(argv)
    if args.step < 1:
        parser.error("--step must be a positive integer")
    if args.n_points < 1:
        parser.error("--n_points must be a positive integer")
    if args.lim is not None and len(args.lim) != 2:
        parser.error("--lim expects two comma-separated values, e.g. 0,5")
    loss = np.loadtxt(args.metrics)

    STEP = args.step
    N_POINTS = args.n_points

    fig, ax = plt.subplots()
    lr = LinearRegression()
    sc = ax.scatter(loss[:N_POINTS, 0],
                    loss[:N_POINTS, 1],
                    c=colors(loss[:N_POINTS, 2]))
    lims = args.lim if args.lim else [0, loss[:, 0].max()]
    ln, = ax.plot(lims, lims, "--", color="black", label="linear fit")
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    # Raw strings: "\c" and "\h" are not valid Python escape sequences.
    ax.set_xlabel(r"$L(\cdot)$")
    ax.set_ylabel(r"$\hat{L}(\cdot)$")
    if args.colorbar:
        # Scale the color bar with the same window as the first frame.
        mappable = ScalarMappable(cmap="viridis")
        mappable.set_array(loss[:N_POINTS, 2])
        plt.colorbar(mappable, ax=ax)

    def update(i):
        s = i * STEP
        e = s + N_POINTS
        lr.fit(loss[s:e, :1], loss[s:e, 1].ravel())
        ln.set_ydata([
            lr.intercept_.ravel(),
            lr.intercept_.ravel() + lims[1] * lr.coef_.ravel()
        ])
        ax.set_title("Showing samples %d to %d" % (s, e))
        sc.set_facecolor(colors(loss[s:e, 2]))
        sc.set_offsets(loss[s:e, :2])
        return ax, sc, ln

    # Integer frame indices (ceil(len / STEP) of them) so the slice bounds
    # computed in ``update`` are ints; float indices raise TypeError.
    n_frames = (len(loss) + STEP - 1) // STEP
    anim = FuncAnimation(fig,
                         update,
                         interval=100,
                         frames=np.arange(n_frames)[args.frames],
                         blit=False,
                         repeat=False)
    if args.to_file:
        writer = animation_writers["ffmpeg"](fps=15)
        anim.save(args.to_file, writer=writer)
    else:
        plt.show()
Exemple #51
0
def parallel_axes_plot(archive,
                       ax=None,
                       bc_order=None,
                       cmap="magma",
                       linewidth=1.5,
                       alpha=0.8,
                       vmin=None,
                       vmax=None,
                       sort_archive=False,
                       cbar_orientation='horizontal',
                       cbar_pad=0.1):
    """Visualizes archive entries in behavior space with a parallel axes plot.

    This visualization is meant to show the coverage of the behavior space at a
    glance. Each axis represents one behavioral dimension, and each line in the
    diagram represents one entry in the archive. Three main things are evident
    from this plot:

    - **Behavior space coverage,** as determined by the amount of the axis that
      has lines passing through it. If the lines are passing through all parts
      of the axis, then there is likely good coverage for that BC.

    - **Correlation between neighboring BCs.** In the below example, we see
      perfect correlation between ``behavior_0`` and ``behavior_1``, since none
      of the lines cross each other. We also see the perfect negative
      correlation between ``behavior_3`` and ``behavior_4``, indicated by the
      crossing of all lines at a single point.

    - **Whether certain values of the behavior dimensions affect the objective
      value strongly.** In the below example, we see ``behavior_2`` has many
      entries with high objective near zero. This is more visible when
      ``sort_archive`` is passed in, as entries with higher objective values
      will be plotted on top of individuals with lower objective values.

    Examples:
        .. plot::
            :context: close-figs

            >>> import numpy as np
            >>> import matplotlib.pyplot as plt
            >>> from ribs.archives import GridArchive
            >>> from ribs.visualize import parallel_axes_plot
            >>> # Populate the archive with the negative sphere function.
            >>> archive = GridArchive(
            ...               [20, 20, 20, 20, 20],
            ...               [(-1, 1), (-1, 1), (-1, 1), (-1, 1), (-1, 1)]
            ...           )
            >>> archive.initialize(solution_dim=3)
            >>> for x in np.linspace(-1, 1, 100):
            ...     for y in np.linspace(0, 1, 100):
            ...         for z in np.linspace(-1, 1, 100):
            ...             archive.add(
            ...                 solution=np.array([x,y,z]),
            ...                 objective_value=-(x**2 + y**2 + z**2),
            ...                 behavior_values=np.array([0.5*x,x,y,z,-0.5*z]),
            ...             )
            >>> # Plot a heatmap of the archive.
            >>> plt.figure(figsize=(8, 6))
            >>> parallel_axes_plot(archive)
            >>> plt.title("Negative sphere function")
            >>> plt.ylabel("axis values")
            >>> plt.show()

    Args:
        archive (ArchiveBase): Any ribs archive.
        ax (matplotlib.axes.Axes): Axes on which to create the plot.
            If None, the current axis will be used.
        bc_order (list of int or list of (int, str)): If this is a list of ints,
            it specifies the axes order for BCs (e.g. ``[2, 0, 1]``). If this is
            a list of tuples, each tuple takes the form ``(int, str)`` where the
            int specifies the BC index and the str specifies a name for the BC
            (e.g. ``[(1, "y-value"), (2, "z-value"), (0, "x-value")]``). The
            order specified does not need to have the same number of elements as
            the number of behaviors in the archive, e.g. ``[1, 3]`` or
            ``[1, 2, 3, 2]``, but it must not be empty.
        cmap (str, list, matplotlib.colors.Colormap): Colormap to use when
            plotting intensity. Either the name of a colormap, a list of RGB or
            RGBA colors (i.e. an Nx3 or Nx4 array), or a colormap object.
        linewidth (float): Line width for each entry in the plot.
        alpha (float): Opacity of the line for each entry (passing a low value
            here may be helpful if there are many archive entries, as more
            entries would be visible).
        vmin (float): Minimum objective value to use in the plot. If None, the
            minimum objective value in the archive is used.
        vmax (float): Maximum objective value to use in the plot. If None, the
            maximum objective value in the archive is used.
        sort_archive (boolean): if true, sorts the archive so that the highest
            performing entries are plotted on top of lower performing entries.

            .. warning:: This may be slow for large archives.
        cbar_orientation (str): The orientation of the colorbar. Use either
            ``'vertical'`` or ``'horizontal'``
        cbar_pad (float): The amount of padding to use for the colorbar.

    Raises:
        ValueError: ``cbar_orientation`` has an invalid value.
        ValueError: The bcs provided do not exist in the archive, or
            ``bc_order`` is empty.
        TypeError: bc_order is not a list of all ints or all tuples.
    """
    # Try getting the colormap early in case it fails.
    cmap = _retrieve_cmap(cmap)

    # Check that the orientation input is correct.
    if cbar_orientation not in ['vertical', 'horizontal']:
        raise ValueError("cbar_orientation must be 'vertical' or 'horizontal' "
                         f"but is '{cbar_orientation}'")

    # If there is no order specified, plot in increasing numerical order.
    if bc_order is None:
        cols = [f"behavior_{i}" for i in range(archive.behavior_dim)]
        axis_labels = cols
        lower_bounds = archive.lower_bounds
        upper_bounds = archive.upper_bounds

    # Use the requested behaviors (may be less than the original number of bcs).
    else:
        if len(bc_order) == 0:
            raise ValueError("Invalid Behavior: bc_order must contain at "
                             "least one behavior index.")

        # Check for errors in specification. The (int, str) check guards the
        # ``len`` call so that a mix of ints and pairs reaches the TypeError
        # below instead of failing inside ``len``.
        if all(isinstance(bc, int) for bc in bc_order):
            bc_indices = np.array(bc_order)
            axis_labels = [f"behavior_{i}" for i in bc_indices]
        elif all(
                isinstance(bc, (tuple, list)) and len(bc) == 2 and
                isinstance(bc[0], int) and isinstance(bc[1], str)
                for bc in bc_order):
            bc_indices, axis_labels = zip(*bc_order)
            bc_indices = np.array(bc_indices)
        else:
            raise TypeError("bc_order must be a list of ints or a list of "
                            "tuples in the form (int, str)")

        if np.max(bc_indices) >= archive.behavior_dim:
            raise ValueError(f"Invalid Behavior: requested behavior index "
                             f"{np.max(bc_indices)}, but archive only has "
                             f"{archive.behavior_dim} behaviors.")
        if any(bc < 0 for bc in bc_indices):
            raise ValueError("Invalid Behavior: requested a negative behavior"
                             " index.")

        # Find the indices of the requested order.
        cols = [f"behavior_{i}" for i in bc_indices]
        lower_bounds = archive.lower_bounds[bc_indices]
        upper_bounds = archive.upper_bounds[bc_indices]

    host_ax = plt.gca() if ax is None else ax  # Try to get current axis.
    df = archive.as_pandas(include_solutions=False)
    if sort_archive:
        df.sort_values(by=['objective'], inplace=True)
    vmin = np.min(df['objective']) if vmin is None else vmin
    vmax = np.max(df['objective']) if vmax is None else vmax
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    objectives = df['objective'].to_numpy()
    ys = df[cols].to_numpy()
    y_ranges = upper_bounds - lower_bounds

    # Transform all data to be in the first axis coordinates.
    normalized_ys = np.zeros_like(ys)
    normalized_ys[:, 0] = ys[:, 0]
    normalized_ys[:, 1:] = (
        (ys[:, 1:] - lower_bounds[1:]) / y_ranges[1:] * y_ranges[0] +
        lower_bounds[0])

    # Copy the axis for the other bcs.
    axes = [host_ax] + [host_ax.twinx() for i in range(len(cols) - 1)]
    for i, axis in enumerate(axes):
        axis.set_ylim(lower_bounds[i], upper_bounds[i])
        axis.spines['top'].set_visible(False)
        axis.spines['bottom'].set_visible(False)
        if axis != host_ax:
            axis.spines['left'].set_visible(False)
            axis.yaxis.set_ticks_position('right')
            axis.spines["right"].set_position(("axes", i / (len(cols) - 1)))

    host_ax.set_xlim(0, len(cols) - 1)
    host_ax.set_xticks(range(len(cols)))
    host_ax.set_xticklabels(axis_labels)
    host_ax.tick_params(axis='x', which='major', pad=7)
    host_ax.spines['right'].set_visible(False)
    host_ax.xaxis.tick_top()

    for archive_entry, objective in zip(normalized_ys, objectives):
        # Draw straight lines between the axes in the appropriate color.
        color = cmap(norm(objective))
        host_ax.plot(range(len(cols)),
                     archive_entry,
                     c=color,
                     alpha=alpha,
                     linewidth=linewidth)

    # Create a colorbar. The proxy mappable has no axes of its own, so the
    # colorbar explicitly takes its space from host_ax.
    mappable = ScalarMappable(cmap=cmap)
    mappable.set_clim(vmin, vmax)
    host_ax.figure.colorbar(mappable,
                            ax=host_ax,
                            pad=cbar_pad,
                            orientation=cbar_orientation)
Exemple #52
0
def trajectory(
    adata: AnnData,
    basis: str = "umap",
    root_milestone: Union[str, None] = None,
    milestones: Union[str, None] = None,
    color_seg: str = "t",
    cmap_seg: str = "viridis",
    layer_seg: Union[str, None] = "fitted",
    perc_seg: Union[List, None] = None,
    color_cells: Union[str, None] = None,
    scale_path: float = 1,
    arrows: bool = False,
    arrow_offset: int = 10,
    show_info: bool = True,
    ax=None,
    show: Optional[bool] = None,
    save: Union[str, bool, None] = None,
    **kwargs,
):
    """\
    Project trajectory onto embedding.

    Parameters
    ----------
    adata
        Annotated data matrix.
    basis
        Name of the `obsm` basis to use.
    root_milestone
        tip defining progenitor branch.
    milestones
        tips defining the progenies branches.
    color_seg
        key whose values color the trajectory segments (resolved through
        `_get_color_values`).
    cmap_seg
        colormap used for the trajectory segments and arrows.
    layer_seg
        layer to use when coloring seg with a feature.
    perc_seg
        percentile cutoffs for segments.
    color_cells
        cells color.
    scale_path
        changes the width of the path
    arrows
        display arrows on segments (positioned at half pseudotime distance).
    arrow_offset
        arrow offset in number of nodes used to obtain its direction. It is
        clamped per segment so it never runs past the segment's last node.
    show_info
        display legend/colorbar.
    ax
        Add plot to existing ax
    show
        show the plot.
    save
        save the plot.
    kwargs
        arguments to pass to :func:`scanpy.pl.embedding`.

    Returns
    -------
    If `show==False` a :class:`~matplotlib.axes.Axes`

    """
    class GC(GraphicsContextBase):
        def __init__(self):
            super().__init__()
            self._capstyle = "round"

    def custom_new_gc(self):
        return GC()

    # NOTE(review): this replaces RendererBase.new_gc at class level, i.e. for
    # the whole process, not only for this plot.
    RendererBase.new_gc = types.MethodType(custom_new_gc, RendererBase)

    if "graph" not in adata.uns:
        raise ValueError(
            "You need to run `tl.pseudotime` first before plotting.")

    graph = adata.uns["graph"]

    emb_f = adata[graph["cells_fitted"], :].obsm[f"X_{basis}"]

    # Keep the requested (1-based) components, or the first two.
    if "components" in kwargs:
        cmp = np.array(kwargs["components"]) - 1
        emb_f = emb_f[:, cmp]
    else:
        emb_f = emb_f[:, :2]

    R = graph["R"]

    # Place each graph node at the R-weighted mean of the fitted cells.
    nodes = graph["pp_info"].index
    proj = pd.DataFrame((np.dot(emb_f.T, R) / R.sum(axis=0)).T, index=nodes)

    B = graph["B"]
    g = igraph.Graph.Adjacency((B > 0).tolist(), mode="undirected")
    g.vs[:]["name"] = [v.index for v in g.vs]

    miles_ids = np.concatenate([graph["tips"], graph["forks"]])

    if root_milestone is not None:
        # Restrict to the nodes on shortest paths root -> milestones.
        dct = graph["milestones"]
        nodes = g.get_all_shortest_paths(dct[root_milestone],
                                         [dct[m] for m in milestones])
        nodes = np.unique(np.concatenate(nodes))
        proj = proj.loc[nodes, :]
        g.delete_vertices(
            graph["pp_info"].index[~graph["pp_info"].index.isin(nodes)])
        # NOTE(review): the embedding is drawn here and again further below.
        if ax is None:
            ax = sc.pl.embedding(adata, show=False, basis=basis, **kwargs)
        else:
            sc.pl.embedding(adata, show=False, ax=ax, basis=basis, **kwargs)

    # adata.obs.edge holds "u|v" strings of node ids, parsed to int pairs.
    c_edges = np.array([e.split("|") for e in adata.obs.edge], dtype=int)
    cells = [any(np.isin(c_e, nodes)) for c_e in c_edges]

    import logging

    # Temporarily raise the anndata logger to ERROR while subsetting and
    # plotting. setLevel (unlike assigning `.level`) clears logging's
    # isEnabledFor cache so the change takes effect, and `finally` restores
    # the previous level even if plotting raises.
    anndata_logger = logging.getLogger("anndata")
    prelog = anndata_logger.level
    anndata_logger.setLevel(logging.ERROR)
    try:
        adata_c = adata[cells, :]

        if is_categorical(adata, color_cells):
            if (color_cells + "_colors" not in adata.uns
                    or len(adata.uns[color_cells + "_colors"]) == 1):
                from . import palette_tools

                palette_tools._set_default_colors_for_categorical_obs(
                    adata, color_cells)

            # NOTE(review): adata_c is not used for plotting below.
            adata_c.uns[color_cells + "_colors"] = [
                adata.uns[color_cells + "_colors"][np.argwhere(
                    adata.obs[color_cells].cat.categories == c)[0][0]]
                for c in adata_c.obs[color_cells].cat.categories
            ]

        if ax is None:
            ax = sc.pl.embedding(adata,
                                 color=color_cells,
                                 basis=basis,
                                 show=False,
                                 **kwargs)
        else:
            sc.pl.embedding(adata,
                            color=color_cells,
                            basis=basis,
                            ax=ax,
                            show=False,
                            **kwargs)
    finally:
        anndata_logger.setLevel(prelog)

    if show_info == False and color_cells is not None:
        if is_categorical(adata, color_cells):
            if ax.get_legend() is not None:
                ax.get_legend().remove()
        else:
            ax.set_box_aspect(aspect=1)
            fig = ax.get_gridspec().figure
            cbar = np.argwhere(
                ["colorbar" in a.get_label() for a in fig.get_axes()]).ravel()
            if len(cbar) > 0:
                fig.get_axes()[cbar[0]].remove()

    al = np.array(g.get_edgelist())

    edges = [g.vs[e.tolist()]["name"] for e in al]

    lines = [[tuple(proj.loc[j]) for j in i] for i in edges]

    vals = pd.Series(_get_color_values(adata, color_seg, layer=layer_seg)[0],
                     index=adata.obs_names)
    R = pd.DataFrame(adata.uns["graph"]["R"],
                     index=adata.uns["graph"]["cells_fitted"])
    R = R.loc[adata.obs_names]
    vals = vals[~np.isnan(vals)]
    R = R.loc[vals.index]

    def get_nval(i):
        return np.average(vals, weights=R.loc[:, i])

    # Node value: R-weighted average of cell values.
    # Segment value: mean of its two end-node values.
    node_vals = np.array(list(map(get_nval, range(R.shape[1]))))
    seg_val = node_vals[np.array(edges)].mean(axis=1)

    if perc_seg is not None:
        min_v, max_v = np.percentile(seg_val, perc_seg)
        seg_val[seg_val < min_v] = min_v
        seg_val[seg_val > max_v] = max_v

    sm = ScalarMappable(norm=Normalize(vmin=seg_val.min(), vmax=seg_val.max()),
                        cmap=cmap_seg)

    # Sort segments by value so higher values are drawn last.
    order = np.argsort(seg_val)
    lines = [lines[i] for i in order]
    seg_val = seg_val[order]

    # Black outline drawn underneath the colored segments.
    lc = matplotlib.collections.LineCollection(lines,
                                               colors="k",
                                               linewidths=7.5 * scale_path,
                                               zorder=100)
    ax.add_collection(lc)

    # Rebuild the full graph; the one above may have had vertices deleted.
    g = igraph.Graph.Adjacency((B > 0).tolist(), mode="undirected")
    seg = graph["pp_seg"].loc[:, ["from", "to"]].values.tolist()

    if arrows:
        for s in seg:
            path = np.array(g.get_shortest_paths(s[0], s[1])[0])
            coord = proj.loc[path, ].values
            out = np.empty(len(path) - 1)
            cdist_numba(coord, out)
            # presumably `out` holds consecutive node distances; pick the node
            # nearest to half of the cumulative path length.
            mid = np.argmin(np.abs(out.cumsum() - out.sum() / 2))
            # Clamp per segment without overwriting `arrow_offset`, so one
            # short segment does not shrink the offset for later segments.
            offset = min(arrow_offset, len(path) - 1 - mid)
            x0 = proj.loc[path[mid], 0]
            y0 = proj.loc[path[mid], 1]
            dx = proj.loc[path[mid + offset], 0] - x0
            dy = proj.loc[path[mid + offset], 1] - y0
            # Larger black arrow first, acting as an outline.
            ax.quiver(
                x0,
                y0,
                dx,
                dy,
                headwidth=15 * scale_path,
                headaxislength=10 * scale_path,
                headlength=10 * scale_path,
                units="dots",
                zorder=101,
            )
            c_arrow = vals[(c_edges == path[mid]).sum(axis=1) == 1].mean()
            ax.quiver(
                x0,
                y0,
                dx,
                dy,
                headwidth=12 * scale_path,
                headaxislength=10 * scale_path,
                headlength=10 * scale_path,
                units="dots",
                color=sm.to_rgba(c_arrow),
                zorder=102,
            )

    lc = matplotlib.collections.LineCollection(
        lines,
        colors=[sm.to_rgba(sv) for sv in seg_val],
        linewidths=5 * scale_path,
        zorder=104,
    )

    miles_ids = miles_ids[np.isin(miles_ids, proj.index)]

    ax.scatter(
        proj.loc[miles_ids, 0],
        proj.loc[miles_ids, 1],
        zorder=103,
        c="k",
        s=200 * scale_path,
    )
    ax.add_collection(lc)

    tip_val = node_vals[miles_ids]

    ax.scatter(
        proj.loc[miles_ids, 0],
        proj.loc[miles_ids, 1],
        zorder=105,
        c=sm.to_rgba(tip_val),
        s=140 * scale_path,
    )

    if show == False:
        return ax

    savefig_or_show("trajectory", show=show, save=save)
Exemple #53
0
def cvt_archive_heatmap(archive,
                        ax=None,
                        plot_centroids=True,
                        plot_samples=False,
                        transpose_bcs=False,
                        cmap="magma",
                        square=False,
                        ms=1,
                        lw=0.5,
                        vmin=None,
                        vmax=None):
    """Plots heatmap of a :class:`~ribs.archives.CVTArchive` with 2D behavior
    space.

    Essentially, we create a Voronoi diagram and shade in each cell with a
    color corresponding to the objective value of that cell's elite.

    Depending on how many bins are in the archive, ``ms`` and ``lw`` may need to
    be tuned. If there are too many bins, the Voronoi diagram and centroid
    markers will make the entire image appear black. In that case, try turning
    off the centroids with ``plot_centroids=False`` or even removing the lines
    completely with ``lw=0``.

    Examples:

        .. plot::
            :context: close-figs

            >>> import numpy as np
            >>> import matplotlib.pyplot as plt
            >>> from ribs.archives import CVTArchive
            >>> from ribs.visualize import cvt_archive_heatmap
            >>> # Populate the archive with the negative sphere function.
            >>> archive = CVTArchive(100, [(-1, 1), (-1, 1)])
            >>> archive.initialize(solution_dim=2)
            >>> for x in np.linspace(-1, 1, 100):
            ...     for y in np.linspace(-1, 1, 100):
            ...         archive.add(solution=np.array([x,y]),
            ...                     objective_value=-(x**2 + y**2),
            ...                     behavior_values=np.array([x,y]))
            >>> # Plot a heatmap of the archive.
            >>> plt.figure(figsize=(8, 6))
            >>> cvt_archive_heatmap(archive)
            >>> plt.title("Negative sphere function")
            >>> plt.xlabel("x coords")
            >>> plt.ylabel("y coords")
            >>> plt.show()

    Args:
        archive (CVTArchive): A 2D CVTArchive.
        ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If None,
            the current axis will be used.
        plot_centroids (bool): Whether to plot the cluster centroids.
        plot_samples (bool): Whether to plot the samples used when generating
            the clusters.
        transpose_bcs (bool): By default, the first BC in the archive will
            appear along the x-axis, and the second will be along the y-axis. To
            switch this (i.e. to transpose the axes), set this to True.
        cmap (str, list, matplotlib.colors.Colormap): Colormap to use when
            plotting intensity. Either the name of a colormap, a list of RGB or
            RGBA colors (i.e. an Nx3 or Nx4 array), or a colormap object.
        square (bool): If True, set the axes aspect ratio to be "equal".
        ms (float): Marker size for both centroids and samples.
        lw (float): Line width when plotting the voronoi diagram.
        vmin (float): Minimum objective value to use in the plot. If None, the
            minimum objective value in the archive is used.
        vmax (float): Maximum objective value to use in the plot. If None, the
            maximum objective value in the archive is used. If the resulting
            range is empty (``vmin == vmax``, e.g. a single elite), filled
            cells are drawn with the lowest color of the colormap.
    Raises:
        ValueError: The archive is not 2D.
    """
    # pylint: disable = too-many-locals

    if archive.behavior_dim != 2:
        raise ValueError("Cannot plot heatmap for non-2D archive.")

    # Try getting the colormap early in case it fails.
    cmap = _retrieve_cmap(cmap)

    # Retrieve data from archive.
    lower_bounds = archive.lower_bounds
    upper_bounds = archive.upper_bounds
    centroids = archive.centroids
    samples = archive.samples
    if transpose_bcs:
        lower_bounds = np.flip(lower_bounds)
        upper_bounds = np.flip(upper_bounds)
        centroids = np.flip(centroids, axis=1)
        samples = np.flip(samples, axis=1)

    # Retrieve and initialize the axis.
    ax = plt.gca() if ax is None else ax
    ax.set_xlim(lower_bounds[0], upper_bounds[0])
    ax.set_ylim(lower_bounds[1], upper_bounds[1])
    if square:
        ax.set_aspect("equal")

    # Add faraway points so that the edge regions of the Voronoi diagram are
    # filled in. Refer to
    # https://stackoverflow.com/questions/20515554/colorize-voronoi-diagram
    # for more info.
    interval = upper_bounds - lower_bounds
    scale = 1000
    faraway_pts = [
        upper_bounds + interval * scale,  # Far upper right.
        upper_bounds + interval * [-1, 1] * scale,  # Far upper left.
        lower_bounds + interval * [-1, -1] * scale,  # Far bottom left.
        lower_bounds + interval * [1, -1] * scale,  # Far bottom right.
    ]
    vor = Voronoi(np.append(centroids, faraway_pts, axis=0))

    # Calculate objective value for each region. `vor.point_region` contains
    # the region index of each point.
    region_obj = [None] * len(vor.regions)
    min_obj, max_obj = np.inf, -np.inf
    pt_to_obj = _get_pt_to_obj(archive)
    for pt_idx, region_idx in enumerate(
            vor.point_region[:-4]):  # Exclude faraway_pts.
        if region_idx != -1 and pt_idx in pt_to_obj:
            obj = pt_to_obj[pt_idx]
            min_obj = min(min_obj, obj)
            max_obj = max(max_obj, obj)
            region_obj[region_idx] = obj

    # Override objective value range.
    min_obj = min_obj if vmin is None else vmin
    max_obj = max_obj if vmax is None else vmax
    obj_range = max_obj - min_obj

    # Shade the regions.
    for region, objective in zip(vor.regions, region_obj):
        # This check is O(n), but n is typically small, and creating
        # `polygon` is also O(n) anyway.
        if -1 not in region:
            if objective is None:
                color = "white"
            elif obj_range == 0:
                # Empty range (single elite, equal objectives, or
                # vmin == vmax): use the bottom of the colormap, as
                # matplotlib's Normalize does, instead of dividing by zero.
                color = cmap(0.0)
            else:
                normalized_obj = np.clip(
                    (objective - min_obj) / obj_range, 0.0, 1.0)
                color = cmap(normalized_obj)
            polygon = [vor.vertices[i] for i in region]
            ax.fill(*zip(*polygon), color=color, ec="k", lw=lw)

    # Create a colorbar.
    mappable = ScalarMappable(cmap=cmap)
    mappable.set_clim(min_obj, max_obj)
    ax.figure.colorbar(mappable, ax=ax, pad=0.1)

    # Plot the sample points and centroids.
    if plot_samples:
        ax.plot(samples[:, 0], samples[:, 1], "o", c="gray", ms=ms)
    if plot_centroids:
        ax.plot(centroids[:, 0], centroids[:, 1], "ko", ms=ms)
Exemple #54
0
def plot_roi_summary(roi_pac_all,
                     reject,
                     xaxis,
                     marker_size,
                     eventnames,
                     lobe_bounds,
                     lobe_names,
                     ordered_rois,
                     yl_factor=1,
                     ylim=None):
    """Scatter per-ROI values in one stacked subplot per level.

    Args:
        roi_pac_all: values to plot, shape (1 x Nlevels x Nrois) or
            (Nsubjects x Nlevels x Nrois).
        reject: mask with the same shape as ``roi_pac_all``. True entries are
            drawn as color-coded circles, the others as black '+' markers.
        xaxis: x positions, shape (1 x Nrois) or (Nsubjects x Nrois).
        marker_size: marker size passed to ``Axes.scatter``.
        eventnames: level names; one subplot row per name.
        lobe_bounds: sequence of (start, end) x-position pairs, one per lobe.
            Vertical separators are drawn midway between consecutive lobes.
        lobe_names: labels written above the first subplot at each lobe's
            center.
        ordered_rois: ROI tick labels, shown only on the last level.
        yl_factor: factor applied to ``ylim`` and to the height ratio of every
            level except the last.
        ylim: (low, high) y-limits, also used as the color range. Defaults to
            +/- 1.1 * nanmax(|roi_pac_all|).

    Returns:
        (fig, axs, gs): the figure, the list of per-level axes, and the
        GridSpec. Its final row (height ratio 0.5) is left empty.
    """
    Nlevels = len(eventnames)

    if ylim is None:
        ylim = np.nanmax(np.abs(roi_pac_all)) * 1.1 * np.array([-1, 1])
    # Coerce to a float array so `ylim * yl_factor` scales the limits instead
    # of failing (list * float) or repeating (tuple * int) the sequence.
    ylim = np.asarray(ylim, dtype=float)
    cmap = plt.get_cmap('seismic')
    norm = Normalize(vmin=ylim[0], vmax=ylim[1])

    fig = plt.figure()
    fig.set_size_inches([17, 9.5])
    gs = gridspec.GridSpec(Nlevels + 1,
                           1,
                           height_ratios=[yl_factor] * (Nlevels - 1) +
                           [1, 0.5],
                           hspace=0)
    axs = []
    for level_i, level_name in enumerate(eventnames):
        ax = fig.add_subplot(gs[level_i])
        toplot_level = roi_pac_all[:, level_i, :]
        mask_level = reject[:, level_i, :]
        unmasked = np.logical_not(mask_level)

        ax.scatter(xaxis[mask_level],
                   toplot_level[mask_level],
                   s=marker_size,
                   c=toplot_level[mask_level],
                   cmap=cmap,
                   norm=norm,
                   marker='o',
                   edgecolors='k',
                   zorder=3)
        # A fixed color needs no cmap/norm (matplotlib would ignore them).
        ax.scatter(xaxis[unmasked],
                   toplot_level[unmasked],
                   s=marker_size,
                   c='k',
                   marker='+',
                   zorder=3)
        ax.axhline(0, color='k')
        for lobe, next_lobe in zip(lobe_bounds[:-1], lobe_bounds[1:]):
            ax.axvline((lobe[1] + next_lobe[0]) / 2, color='k')

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        ax.set_xticks(xaxis[0, :])
        if level_i == Nlevels - 1:
            ax.set_xticklabels(ordered_rois, rotation='vertical')
        else:
            ax.set_xticklabels([])
        hdiff = (xaxis[0, 1] - xaxis[0, 0])
        ax.set_xlim([xaxis[0, 0] - hdiff, xaxis[0, -1] + hdiff])

        ax.set_ylabel(level_name)
        if level_i < Nlevels - 1:
            ax.set_ylim(ylim * yl_factor)
        else:
            ax.set_ylim(ylim)

        if level_i == 0:
            for lobe, name in zip(lobe_bounds, lobe_names):
                ax.text(np.mean(lobe),
                        ax.get_ylim()[1] * 1.1,
                        name,
                        horizontalalignment='center')

        axs.append(ax)

    return fig, axs, gs
Exemple #55
0
import os
import sys

import numpy as np

# Colormaps to render. Any name from matplotlib.pyplot.colormaps() works.
maps = [
    "afmhot", "CMRmap", "flag", "cubehelix", "Blues_r", "bone", "BrBG",
    "BrBG_r", "BuGn_r", "binary_r", "coolwarm", "Spectral", "Spectral_r",
    "terrain", "gist_stern", "gnuplot", "hot", "inferno", "plasma", "prism",
    "Reds", "seismic"
]


def colorized_filename(src_path, cmap_name):
    """Return the BMP output path for ``src_path`` rendered with ``cmap_name``.

    Only the final extension is stripped, so dots in directory names (e.g.
    ``./img.png``) or earlier in the file name are preserved.
    """
    root, _ = os.path.splitext(src_path)
    return root + '_color_' + cmap_name + '.bmp'


def colorize(src_path, cmap_names=None):
    """Save one false-color BMP of the greyscale version of ``src_path`` per
    colormap in ``cmap_names`` (default: ``maps``)."""
    # Imported lazily so the module can be imported without matplotlib/Pillow.
    from matplotlib.cm import ScalarMappable
    from PIL import Image

    cmap_names = maps if cmap_names is None else cmap_names

    # Load and convert once: this does not depend on the colormap. The `with`
    # block closes the underlying file.
    with Image.open(src_path) as img:
        grey = np.asarray(img.convert('L'))  # (height, width), as numpy expects

    for cmap_name in cmap_names:
        mapper = ScalarMappable(cmap=cmap_name)
        # The default Normalize autoscales to the image's min/max. A 2-D
        # input keeps its (height, width) layout and gains an RGBA axis, so
        # non-square images are not scrambled.
        rgba = mapper.to_rgba(grey)
        image = Image.fromarray(np.uint8(rgba * 255))

        filename = colorized_filename(src_path, cmap_name)
        print('Saving', filename)
        image.save(filename, "BMP")


if __name__ == "__main__":
    colorize(sys.argv[1])
Exemple #56
0
def plot_projection_surfaces(projections,
                             eventnames,
                             vertices,
                             subject,
                             subjects_dir,
                             savename,
                             Nprojs=3,
                             blnOffscreen=False,
                             savepath=None):
    '''
    Plot the source space projections of the frequency profiles

    See Figure 3 in Stephen et al 2019

    :param projections: ndarray, (components,levels,sources)
        The projections to plot, e.g. from spatial_phase_amplitude_coupling.compute_source_space_projections.
        The array is not modified: NaNs are zeroed in a per-component copy.
    :param eventnames: list of str
        List of names of the levels
    :param vertices: list of numeric
        the vertices corresponding to the sources in projections
    :param subject: str
        The subject ID
    :param subjects_dir:
        The subject directory
    :param savename:
        The name of the file to save to
    :param Nprojs: int
        The number of projections to plot
    :param blnOffscreen: bool
        Whether to plot the surfaces offscreen
    :param savepath: str | None
        The path to save to (or None for no save). The brain figure is only
        closed when saving.
    :return: None
    '''
    def label_func(f):
        return eventnames[int(f)]

    for proji in range(Nprojs):
        # Basic slicing returns a view, so copy before zeroing NaNs to avoid
        # modifying the caller's `projections` in place.
        toplot = projections[proji, :, :].copy()
        toplot[np.isnan(toplot)] = 0

        # Symmetric color limits at +/- the largest absolute value.
        lim = np.max(np.abs(toplot))
        ctrl_pts = [-lim, 0, lim]
        stc_proj = mne.SourceEstimate(toplot.T,
                                      vertices=vertices,
                                      tmin=0,
                                      tstep=1,
                                      subject=subject)
        brain = plot_surf(stc_proj, {
            'kind': 'value',
            'lims': ctrl_pts
        },
                          'seismic',
                          label_func,
                          'semi-inflated',
                          False,
                          blnOffscreen,
                          subjects_dir=subjects_dir,
                          blnMaskUnknown=True)

        if savepath is not None:
            fname = savename + '_FreqProfs_surfProj{}'.format(proji + 1)
            # One image per level (time point) of the source estimate.
            for j, (t, level) in enumerate(zip(stc_proj.times, eventnames)):
                brain.set_time(t)
                save_surf(brain, fname, savepath, '_{}{}'.format(j, level))
            brain.close()
Exemple #57
0
def multiscatter(deamid_mat,
                 key=None,
                 type=None,
                 hts=None,
                 l=None,
                 t=None,
                 path=None,
                 base_name=None,
                 low_counts=None,
                 **kwargs):
    """
    Creates tripeptides multiscatter plot (pairwise scatter of all columns).

    Input:
        - deamid_mat: object providing ``D``, ``layers``, ``counts``,
          ``Ydata`` and ``trps_data``
        - key: Ydata entry info to include in the plot
        - type: whether key is categorical ('cat') or continuous ('cont');
          None draws every point in ``mcolor``
        - hts: per-tripeptide values used to color the diagonal panels
          (entries equal to -1 get a white background); None or empty
          disables this coloring and its colorbar
        - l: index of layer to plot (default: ``deamid_mat.D``)
        - t: transformation passed to ``transform`` for continuous data
        - path, base_name: if either is given, the figure is saved to
          ``path + base_name + '_multiscatter.png'`` (a missing part counts
          as ''); otherwise the figure is shown
        - low_counts: if given, a point is kept only when both of its counts
          are greater than this value; otherwise points with a NaN are dropped
        - kwargs: visual options overriding the defaults in ``dfts``
    *Notes:
        - key can refer to numeric but still categorical data
    """

    if l is not None:
        Ddata = deamid_mat.layers[l]
    else:
        Ddata = deamid_mat.D
    counts = deamid_mat.counts
    Ydata = deamid_mat.Ydata
    trps_data = deamid_mat.trps_data

    if key is not None:
        Yvect = Ydata[key]
    else:
        Yvect = np.zeros(Ddata.shape[0])

    # hts defaults to None; treat None like an empty sequence.
    exists_hts = hts is not None and len(hts) > 0

    # Detect samples with known Ydata
    if type == 'cat':
        known = Yvect != 'U'
    elif type == 'cont':
        known = Yvect != -1
    else:
        known = np.ones(len(Yvect), dtype=bool)

    # Apply transformation to continuous data in order to visualize better
    if type == 'cont' and t is not None:
        Yvect = transform(Yvect, t)

    # Include visual options from kwargs into the defaults
    dfts = {
        'fontsize': 12,
        'fontpos': [0.15, 0.35],
        'fontweight': 'medium',
        'mcolor': 'black',
        'msize': 5,
        'reg': None,
        'cat_cmap': 'tab10',
        'bbox_to_anchor': (0.5, 2.2)
    }
    dfts.update(kwargs)

    # Create mapping for either continuous or categorical data
    if type == 'cat':
        Yset = np.sort(list(set(Yvect)))  # Set (array type) of Y values
        keyCm = plt.get_cmap(dfts['cat_cmap'])
        keyNorm = plt.Normalize(vmin=0, vmax=len(Yset))
        keySm = ScalarMappable(norm=keyNorm, cmap=keyCm)
        keySm.set_array([])
        # Handles for the legend
        handles = [
            plt.plot([], [],
                     markersize=18,
                     marker='.',
                     ls='none',
                     c=keySm.to_rgba(v))[0] for v in range(len(Yset))
        ]
        plt.close()

        # Replace each category by its index in Yset.
        map_color = np.zeros(Yvect.shape[0])
        for idx, Yval in enumerate(Yset):
            map_color[Yvect == Yval] = idx
        Yvect = map_color
    elif type == 'cont':
        # Create mappable for colorbar for ages
        keyCm = plt.get_cmap("cool")
        keyNorm = plt.Normalize(np.min(Yvect[known]), np.max(Yvect[known]))
        keySm = ScalarMappable(norm=keyNorm, cmap=keyCm)
        keySm.set_array([])
    else:
        keyCm = plt.get_cmap("brg")
        keyNorm = plt.Normalize(0, 1)
        keySm = ScalarMappable(norm=keyNorm, cmap=keyCm)
        keySm.set_array([])

    # Create map for halftime
    if exists_hts:
        htcm = plt.get_cmap("coolwarm")
        htnorm = plt.Normalize(np.min(hts), np.max(hts))
        htsm = ScalarMappable(norm=htnorm, cmap=htcm)
        htsm.set_array([])

    num_dims = Ddata.shape[1]
    fig = plt.figure(figsize=(16, 12))
    axes = [[False for _ in range(num_dims)] for _ in range(num_dims)]
    n = 1
    if dfts['reg'] == 'linear':
        reg = linear_model.BayesianRidge(fit_intercept=False)
    for i in range(num_dims):
        for j in range(num_dims):
            ax = fig.add_subplot(num_dims, num_dims, n)
            # Extract j (x) and i (y) columns
            plot_data = Ddata[:, [j, i]]
            # Keep rows where both values are usable: enough counts when
            # low_counts is set, otherwise not NaN.
            if low_counts is not None:
                counts_data = counts[:, [j, i]]
                not_miss = counts_data > low_counts
            else:
                not_miss = ~np.isnan(plot_data)
            sele = np.array(np.sum(not_miss, 1) == 2)

            X = plot_data[:, 0]
            Y = plot_data[:, 1]
            sele_known = np.logical_and(sele, known)
            sele_unknown = np.logical_and(sele, ~known)
            if i != j:
                # Fit only on the selected rows, and only when there is more
                # than one of them.
                if np.count_nonzero(sele) > 1 and dfts['reg'] is not None:
                    X_fit, Y_fit = X[sele], Y[sele]
                    x_pred = np.linspace(np.nanmin(X), np.nanmax(X), 500)
                    if dfts['reg'] == 'linear':
                        # Linear 1to1
                        reg.fit(X_fit.reshape(-1, 1), Y_fit)
                        y_pred = reg.predict(x_pred.reshape(-1, 1))
                        ax.plot(x_pred.reshape(-1, 1),
                                y_pred,
                                color='lime',
                                linewidth=1)
                    else:
                        # Non-linear regression, 1 parameter a:
                        # y = exp(a * (x - 1))
                        a0 = 1
                        res1 = least_squares(lsf1par,
                                             a0,
                                             loss='soft_l1',
                                             f_scale=0.1,
                                             args=(X_fit, Y_fit))
                        a = res1.x
                        y_pred = np.exp(a * (x_pred - 1))
                        ax.plot(x_pred, y_pred, color='black', linewidth=1)

                        # Inverse model, 1 parameter a: y = log(x) / a + 1
                        a0 = 1
                        res2 = least_squares(lsf1par_inv,
                                             a0,
                                             loss='soft_l1',
                                             f_scale=0.1,
                                             args=(X_fit, Y_fit))
                        a = res2.x
                        x_pred1 = np.linspace(0.0001, 1.1, 500)
                        y_pred = np.log(x_pred1) / a + 1
                        ax.plot(x_pred1, y_pred, color='red', linewidth=1)

                if type is not None:
                    ax.scatter(X[sele_known],
                               Y[sele_known],
                               s=dfts['msize'],
                               alpha=0.7,
                               c=keySm.to_rgba(Yvect[sele_known]))
                    ax.scatter(X[sele_unknown],
                               Y[sele_unknown],
                               s=dfts['msize'],
                               alpha=0.7,
                               c='black')
                else:
                    ax.scatter(X[sele_known],
                               Y[sele_known],
                               s=dfts['msize'],
                               c=dfts['mcolor'],
                               alpha=0.8)
            else:
                # Diagonal panel: tripeptide label, optionally colored by hts.
                txt = trps_data[i]['tripep']
                if trps_data[i][1] != 'NA':
                    txt = trps_data[i][0] + '\n' + txt

                if trps_data[i][4] != 0:
                    txt = txt + str(trps_data[i][4])
                elif trps_data[i][3] != 0:
                    txt = txt + str(trps_data[i][3])
                ax.text(dfts['fontpos'][0],
                        dfts['fontpos'][1],
                        txt,
                        fontweight=dfts['fontweight'],
                        fontsize=dfts['fontsize'],
                        transform=ax.transAxes)
                if exists_hts and hts[i] != -1:
                    ax.set_facecolor(htsm.to_rgba(hts[i]))
                else:
                    ax.set_facecolor('white')
                ax.axis(xmin=np.nanmin(X),
                        xmax=np.nanmax(X),
                        ymin=np.nanmin(Y),
                        ymax=np.nanmax(Y))
            # Set axis labels:
            ax.set_xlabel(trps_data['tripep'][j])
            ax.set_ylabel(trps_data['tripep'][i])
            # Hide axes for all but the plots on the edge:
            if i < num_dims - 1:
                ax.xaxis.set_visible(False)
            if j > 0:
                ax.yaxis.set_visible(False)
            # Add this axis to the list.
            axes[j][i] = ax
            n += 1
    axes = np.array(axes)
    plt.subplots_adjust(left=0.1,
                        right=0.9,
                        top=0.85,
                        bottom=0.1,
                        wspace=0.2,
                        hspace=0.2)
    # halftime color bar
    if exists_hts:
        htcbar_ax = fig.add_axes([0.92, 0.51, 0.015, 0.38])
        htcbar = plt.colorbar(htsm, cax=htcbar_ax, orientation="vertical")
        htcbar.ax.set_title("Reaction\nconstant")

    if type == 'cont':  # Make a colorbar for it
        keycb_ax = fig.add_axes([0.92, 0.1, 0.015, 0.38])
        keycbar = plt.colorbar(keySm, cax=keycb_ax, orientation="vertical")
        if t is not None:
            keycbar.ax.set_title(key + ' ({})'.format(t))
        else:
            keycbar.ax.set_title(key)
    elif type == 'cat':  # Set the legend for it
        leg = plt.legend(handles,
                         Yset,
                         ncol=3,
                         bbox_to_anchor=dfts['bbox_to_anchor'],
                         fontsize='large')
        export_legend(leg, key)
    # Show or save
    if path is None and base_name is None:
        plt.show()
    else:
        # A missing path or base name counts as '' instead of failing on
        # `None + str`.
        plt.savefig((path or '') + (base_name or '') + '_multiscatter.png',
                    dpi=300,
                    format=None)
        plt.close()
Exemple #58
0
    def _get_2d_plot(self,
                     label_stable=True,
                     label_unstable=True,
                     ordering=None,
                     energy_colormap=None,
                     vmin_mev=-60.0,
                     vmax_mev=60.0,
                     show_colorbar=True,
                     process_attributes=False):
        """
        Draw the 2D (binary or ternary) phase diagram and return ``plt``.

        Usually I won't do imports in methods, but since plotting is a fairly
        expensive library to load and not all machines have matplotlib
        installed, I have done it this way.

        Args:
            label_stable (bool): Annotate stable entries with their names.
            label_unstable (bool): Annotate unstable entries with their names
                (unstable entries are only drawn when ``self.show_unstable``
                is true).
            ordering: Optional ordering forwarded to ``order_phase_diagram``
                to rearrange ``self.pd_plot_data`` before plotting.
            energy_colormap: ``None`` to draw plain markers, ``'default'`` for
                a green (below zero) / red (above zero) map, or any colormap
                accepted by ``ScalarMappable`` to colour markers by energy.
            vmin_mev (float): Lower colour-scale limit in meV/atom.
            vmax_mev (float): Upper colour-scale limit in meV/atom.
            show_colorbar (bool): Add a colorbar when ``energy_colormap`` is
                not ``None``.
            process_attributes (bool): Draw entries whose ``attribute`` is
                ``None`` or ``"existing"`` as circles and all others as
                stars; annotate ``"new"`` entries in green.

        Returns:
            The object returned by ``get_publication_quality_plot`` (used
            here through the pyplot interface) with the diagram drawn on its
            current figure.
        """
        plt = get_publication_quality_plot(8, 6)
        from matplotlib.font_manager import FontProperties

        def _offset_and_alignment(coords, center):
            """Return (offset in points, halign, valign) for a label at coords.

            The offset points away from ``center`` with length 10 (it stays
            the zero vector when ``coords`` coincides with ``center``), so
            annotations fan out around the diagram.
            """
            vec = np.array(coords) - center
            length = np.linalg.norm(vec)
            if length != 0:
                vec = vec / length * 10
            valign = "bottom" if vec[1] > 0 else "top"
            if vec[0] < -0.01:
                halign = "right"
            elif vec[0] > 0.01:
                halign = "left"
            else:
                halign = "center"
            return vec, halign, valign

        if ordering is None:
            (lines, labels, unstable) = self.pd_plot_data
        else:
            (_lines, _labels, _unstable) = self.pd_plot_data
            (lines, labels,
             unstable) = order_phase_diagram(_lines, _labels, _unstable,
                                             ordering)
        if energy_colormap is None:
            if process_attributes:
                for x, y in lines:
                    plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
                # One should think about a clever way to have "complex"
                # attributes with complex processing options but with a clear
                # logic. At this moment, I just use the attributes to know
                # whether an entry is a new compound or an existing (from the
                # ICSD or from the MP) one.
                for x, y in labels.keys():
                    if labels[(x, y)].attribute is None or \
                            labels[(x, y)].attribute == "existing":
                        plt.plot(x, y, "ko", linewidth=3,
                                 markeredgecolor="k", markerfacecolor="b",
                                 markersize=12)
                    else:
                        plt.plot(x, y, "k*", linewidth=3,
                                 markeredgecolor="k", markerfacecolor="g",
                                 markersize=18)
            else:
                for x, y in lines:
                    plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k",
                             markerfacecolor="b", markersize=15)
        else:
            from matplotlib.colors import Normalize, LinearSegmentedColormap
            from matplotlib.cm import ScalarMappable
            pda = PDAnalyzer(self._pd)
            for x, y in lines:
                plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
            if energy_colormap == 'default':
                # Position of zero energy on the [0, 1] colormap axis: green
                # shades below it, red shades above it.
                mid = -vmin_mev / (vmax_mev - vmin_mev)
                cmap = LinearSegmentedColormap.from_list(
                    'my_colormap', [(0.0, '#005500'), (mid, '#55FF55'),
                                    (mid, '#FFAAAA'), (1.0, '#FF0000')])
            else:
                cmap = energy_colormap
            # Normalise in meV/at so the colorbar ticks are already in the
            # units named by its label; energies (eV/at) are scaled by 1000
            # whenever they are handed to the mappable. This avoids parsing
            # rendered tick text, which fails on unicode minus signs.
            norm = Normalize(vmin=vmin_mev, vmax=vmax_mev)
            _map = ScalarMappable(norm=norm, cmap=cmap)
            _energies = [
                pda.get_equilibrium_reaction_energy(entry)
                for entry in labels.values()
            ]
            # Non-negative values are replaced by a tiny negative number so
            # every stable entry is coloured from the below-zero range.
            energies = [en if en < 0.0 else -0.00000001 for en in _energies]
            vals_stable = _map.to_rgba(np.asarray(energies) * 1000.0)
            # labels.keys() iterates in the same order as labels.values()
            # above, so each position pairs with its own colour.
            if process_attributes:
                for (x, y), color in zip(labels.keys(), vals_stable):
                    if labels[(x, y)].attribute is None or \
                            labels[(x, y)].attribute == "existing":
                        plt.plot(x, y, "o", markerfacecolor=color,
                                 markersize=12)
                    else:
                        plt.plot(x, y, "*", markerfacecolor=color,
                                 markersize=18)
            else:
                for (x, y), color in zip(labels.keys(), vals_stable):
                    plt.plot(x, y, "o", markerfacecolor=color, markersize=15)

        font = FontProperties()
        font.set_weight("bold")
        font.set_size(24)

        # Sets a nice layout depending on the type of PD. Also defines a
        # "center" for the PD, which then allows the annotations to be spread
        # out in a nice manner.
        if len(self._pd.elements) == 3:
            plt.axis("equal")
            plt.xlim((-0.1, 1.2))
            plt.ylim((-0.1, 1.0))
            plt.axis("off")
            center = (0.5, math.sqrt(3) / 6)
        else:
            all_coords = labels.keys()
            miny = min([c[1] for c in all_coords])
            ybuffer = max(abs(miny) * 0.1, 0.1)
            plt.xlim((-0.1, 1.1))
            plt.ylim((miny - ybuffer, ybuffer))
            center = (0.5, miny / 2)
            plt.xlabel("Fraction", fontsize=28, fontweight='bold')
            plt.ylabel("Formation energy (eV/fu)",
                       fontsize=28,
                       fontweight='bold')

        # Stable labels, drawn from the highest y coordinate downwards.
        for coords in sorted(labels.keys(), key=lambda c: -c[1]):
            entry = labels[coords]
            vec, halign, valign = _offset_and_alignment(coords, center)
            if label_stable:
                extra = {}
                if process_attributes and entry.attribute == 'new':
                    extra['color'] = 'g'
                plt.annotate(latexify(entry.name),
                             coords,
                             xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign,
                             verticalalignment=valign,
                             fontproperties=font,
                             **extra)

        if self.show_unstable:
            font = FontProperties()
            font.set_size(16)
            pda = PDAnalyzer(self._pd)
            energies_unstable = [
                pda.get_e_above_hull(entry) for entry in unstable.keys()
            ]
            if energy_colormap is not None:
                energies.extend(energies_unstable)
                vals_unstable = _map.to_rgba(
                    np.asarray(energies_unstable) * 1000.0)
            for ii, (entry, coords) in enumerate(unstable.items()):
                # Align by this entry's own offset direction, not by the one
                # computed for the last stable label.
                vec, halign, valign = _offset_and_alignment(coords, center)
                if energy_colormap is None:
                    plt.plot(coords[0],
                             coords[1],
                             "ks",
                             linewidth=3,
                             markeredgecolor="k",
                             markerfacecolor="r",
                             markersize=8)
                else:
                    plt.plot(coords[0],
                             coords[1],
                             "s",
                             linewidth=3,
                             markeredgecolor="k",
                             markerfacecolor=vals_unstable[ii],
                             markersize=8)
                if label_unstable:
                    plt.annotate(latexify(entry.name),
                                 coords,
                                 xytext=vec,
                                 textcoords="offset points",
                                 horizontalalignment=halign,
                                 color="b",
                                 verticalalignment=valign,
                                 fontproperties=font)
        if energy_colormap is not None and show_colorbar:
            # The mappable works in meV/at (see the normalisation above), so
            # the default tick labels need no rewriting.
            _map.set_array(np.asarray(energies) * 1000.0)
            cbar = plt.colorbar(_map)
            cbar.set_label(
                'Energy [meV/at] above hull (in red)\nInverse energy ['
                'meV/at] above hull (in green)',
                rotation=-90,
                ha='left',
                va='center')
        f = plt.gcf()
        f.set_size_inches((8, 6))
        plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
        return plt
Exemple #59
0
R_Nfig.savefig(out_dir + base_name + '_loadings_Ri_N.svg', format='svg')
plt.close()
R_Qfig.savefig(out_dir + base_name + '_loadings_Ri_Q.svg', format='svg')
plt.close()

# -------------------------------------------------------------------
# PCA PLOTS AND REGRESSION

# Substrate label of every sample, plus the sorted distinct substrates.
Yvect = merged_deamid_mat.Ydata['Substrate']
Yset = np.sort(list(set(Yvect)))

# Assign each substrate a 'tab20' colour. Colour indices start at 1 and the
# norm spans [0, 2 * number of substrates].
keyCm = plt.get_cmap('tab20')
keyNorm = plt.Normalize(vmin=0, vmax=2 * len(Yset))
keySm = ScalarMappable(norm=keyNorm, cmap=keyCm)
keySm.set_array([])
map_color = {Yval: keySm.to_rgba(idx)
             for idx, Yval in enumerate(Yset, start=1)}
# NOTE: keep the module-level counter `i` at the value a counting loop
# would leave behind (one past the last colour index used), in case later
# script code reads it.
i = len(Yset) + 1

# -------------------------------------------------------------------
### REGRESSION USING PC1 on Thermal AGE WITH BONE AND
# TAR SEEP SAMPLES
# Known-age samples are those whose substrate is bone, tar seep bone or
# dental calculus.
ages = merged_deamid_mat.Ydata['10C Thermal age'].astype('float')
mask = np.logical_or(
    np.logical_or(Yvect == 'Bone', Yvect == 'Tar seep bone'),
    Yvect == 'Dental calculus')
ages_b = ages[mask]
Exemple #60
0
    coadmin = df['coadmin'] + np.random.rand(n)
    inter = df['inter'] + np.random.rand(n)

    norm = mpl.colors.Normalize(vmin=0, vmax=80)
    darkred = '#d62728'
    lightred = '#fae9e9'
    darkblue = '#1f77b4'
    lightblue = '#e8f1f7'

    cmap_female = LinearSegmentedColormap.from_list('female',
                                                    [lightred, darkred])
    cmap_male = LinearSegmentedColormap.from_list('male',
                                                  [lightblue, darkblue])
    cmap_gray = LinearSegmentedColormap.from_list('gray',
                                                  ['#f2f2f2', '#7f7f7f'])
    scmap_female = ScalarMappable(norm=norm, cmap=cmap_female)
    scmap_male = ScalarMappable(norm=norm, cmap=cmap_male)
    scmap_gray = ScalarMappable(norm=norm, cmap=cmap_gray)

    cmap_female.set_over(darkred)
    cmap_male.set_over(darkblue)

    scmap_male.set_array(df['age'].values)
    scmap_female.set_array(df['age'].values)
    scmap_gray.set_array(df['age'].values)

    #
    df['color'] = df.swifter.apply(calc_color_per_gender_and_age,
                                   axis='columns',
                                   args=(
                                       scmap_female,