コード例 #1
0
def makePlots(months, title, months_list, colorbar_title, vmin, vmax, vn, color, rnd=3, lndclr='gray'):
    """
    Draw one map figure per field in ``months`` and display them all at once.

    Intended for climatology, trend and anomaly maps.

    Args:
        months: iterable of fields, one figure is created for each.
        title: text placed in front of the label in every figure title.
        months_list: labels indexed in step with ``months``.
        colorbar_title: y-label written on each colorbar.
        vmin, vmax: colour limits; also the end points of the colorbar ticks.
        vn: number of colorbar ticks; the colormap is built with ``vn - 1``
            levels.
        color: colormap name passed to ``plt.cm.get_cmap``.
        rnd: number of decimals the tick values are rounded to.
        lndclr: land colour used for the land-sea mask.

    Note:
        Every title ends with the fixed period "(1979-2018)".
    """
    print("Plotting in progress")
    tick_values = np.round(np.linspace(vmin, vmax, vn), rnd)

    # Discrete colormap; yellow is registered as the colour for bad values.
    cmap = plt.cm.get_cmap(color, vn - 1)
    cmap.set_bad(color='yellow')

    # Positions of the graticule lines.
    parallels = np.arange(0, 90, 10.)
    meridians = np.arange(0., 360., 30.)

    for indx, field in enumerate(months):
        plt.figure()
        mp = draw_map()
        plt.title("%s for %s (1979-2018)" % (title, months_list[indx]),
                  y=1.08, fontsize=17)
        # Colours diverge around zero; the limits are applied on top of the norm.
        mp.imshow(field, origin='lower', norm=DivergingNorm(0), cmap=cmap,
                  vmin=vmin, vmax=vmax)
        cbar = mp.colorbar(pad=0.6, ticks=tick_values)
        cbar.ax.set_ylabel(colorbar_title, fontsize=14)
        mp.drawlsmask(land_color=lndclr, ocean_color='white', lakes=True, alpha=0.2)
        # Parallels are drawn unlabelled, meridians are labelled on all four sides.
        mp.drawparallels(parallels, labels=[0, 0, 0, 0])
        mp.drawmeridians(meridians, labels=[1, 1, 1, 1])

    plt.show()
コード例 #2
0
ファイル: src_utils.py プロジェクト: nguyenvanessa/src
def plot_heatmap(heatmap, title):
    """
    Display a 2D array as a 'bwr' image whose colour scale is centred on 0.

    The array is transposed before drawing, so the first axis of ``heatmap``
    runs along x and the second axis along y. Tick labels come from the keys
    of the module-level ``__pos_idx_dict__`` (x) and ``__aa_idx_dict__`` (y).

    Args:
        heatmap: 2D array where positions=rows, aas=cols
        title: title of plot
    Returns:
        fig: matplotlib.pyplot figure object
        ax: matplotlib.pyplot axis object

    """
    fig, ax = plt.subplots(figsize=(50, 300))
    plt.imshow(heatmap.T, cmap='bwr', norm=DivergingNorm(0.0))

    # y axis: one tick per column of the input
    ax.set_yticks(np.arange(heatmap.shape[1]))
    ax.set_yticklabels(__aa_idx_dict__.keys())

    # x axis: one tick per row of the input, labels written vertically
    ax.set_xticks(np.arange(heatmap.shape[0]))
    ax.set_xticklabels(__pos_idx_dict__.keys())
    plt.xticks(rotation='vertical')

    ax.set_title(title)

    plt.show()

    return fig, ax
コード例 #3
0
ファイル: extremes.py プロジェクト: teslakit/teslakit
def Plot_GEVParams(xda_gev_var, c_shape='bwr', c_other='hot_r', show=True):
    """
    Plot GEV params for a GEV parameter variable (sea_Hs, swell_1_Hs, ...).

    One panel is drawn for each entry of ``xda_gev_var.parameter``: the values
    selected for that parameter are reshaped to a square grid and shown with
    ``pcolor``. The 'shape' parameter uses a colour scale centred on zero;
    every other parameter uses a plain min/max scale.

    Args:
        xda_gev_var: object exposing ``name``, ``parameter``, ``n_cluster`` and
            ``sel(parameter=...)`` (presumably an xarray.DataArray - confirm
            against callers). It is not modified.
        c_shape: colormap for the 'shape' panel.
        c_other: colormap for all other panels.
        show: when true, ``plt.show()`` is called before returning.

    Returns:
        The matplotlib figure holding the 2x2 panel grid (last panel hidden).
    """
    name = xda_gev_var.name
    params = xda_gev_var.parameter.values[:]

    # Side of the square grid. If the number of clusters is not a perfect
    # square the reshape below cannot succeed and raises.
    ss = int(np.sqrt(len(xda_gev_var.n_cluster)))

    # plot figure
    fig, axs = plt.subplots(2, 2, figsize=(_faspect * _fsize, _fsize))
    axs = [i for sl in axs for i in sl]

    # empty last axis
    axs[3].axis('off')

    for c, par in enumerate(params):
        ax = axs[c]

        # Work on a copy: slicing a NumPy array with [:] only yields a view,
        # so zeroing entries in place would write into the caller's data.
        par_values = xda_gev_var.sel(parameter=par).values.copy()
        par_values[par_values == 1.0e-10] = 0

        rr_pv = np.flipud(np.reshape(par_values, (ss, ss)).T)
        cl = [np.min(par_values), np.max(par_values)]

        if par == 'shape':
            # The diverging norm needs vmin < 0 < vmax; nudge a limit that
            # sits on the wrong side of (or exactly at) zero.
            if cl[0] >= 0:
                cl[0] = -0.000000001
            if cl[1] <= 0:
                cl[1] = +0.000000001
            # The limits live in the norm itself. Passing vmin/vmax next to a
            # norm is redundant and rejected by newer Matplotlib releases.
            color_kwargs = {
                'cmap': c_shape,
                'norm': DivergingNorm(vmin=cl[0], vcenter=0, vmax=cl[1]),
            }
        else:
            color_kwargs = {'cmap': c_other, 'vmin': cl[0], 'vmax': cl[1]}

        cc = ax.pcolor(rr_pv, edgecolor='k', **color_kwargs)
        fig.colorbar(cc, ax=ax)

        # title only; ticks and tick labels are removed
        ax.set_title(par)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])

    ttl = 'GEV - {0}'.format(name)
    fig.suptitle(ttl, fontweight='bold', fontsize=14)

    # show
    if show:
        plt.show()
    return fig
コード例 #4
0
def plot_figure(data_ref, data, delta, levels, sup_title, unit, fmt, ndx, data_latitude, save_name):
    """
    Save a three-panel comparison figure: reference, simulation, relative change.

    The first two panels share ``levels`` and the 'plasma' colormap (one common
    horizontal colorbar); the third panel shows ``delta`` on a fixed
    -100..100 scale in steps of 20 with the 'seismic' colormap and its own
    colorbar. The figure is written to ``save_name + '.png'`` and then closed.

    Args:
        data_ref: 2D reference field.
        data: 2D field to compare with the reference.
        delta: 2D relative-change field, plotted as a percentage.
        levels: contour levels for the first two panels.
        sup_title: figure super-title.
        unit: text written above the first colorbar.
        fmt: tick format of the first colorbar.
        ndx: x tick positions; they are labelled 0, 90, 180, 270, 359, so five
            positions are expected (assumption - confirm against callers).
        data_latitude: latitude values; every sixth one labels a y tick.
        save_name: output path without the '.png' extension.

    Note:
        The three input arrays are modified in place: cells equal to 0 are
        overwritten (assigning None to a float array stores NaN) so that they
        are left blank by ``contourf``.
    """
    fig, ax = plt.subplots(nrows=1, ncols=3, sharey='col', figsize=(11, 8))
    fig.subplots_adjust(hspace=0, wspace=0, bottom=0.2)
    fig.suptitle(sup_title)

    # Blank out exact zeros in every field (in place, see the docstring).
    for field in (data_ref, data, delta):
        dim1, dim2 = where(field == 0)
        field[dim1, dim2] = None

    ax[0].set_title('Simu ref (MV)')
    pc1 = ax[0].contourf(data_ref, levels=levels, cmap='plasma')
    # Was ax[1] (copy-paste slip): the background belongs to this panel.
    ax[0].set_facecolor('white')

    ax[1].set_title('Our simu')
    ax[1].contourf(data, levels=levels, cmap='plasma')
    ax[1].set_facecolor('white')

    ax[2].set_title('Relative change (simu - ref)')
    pc2 = ax[2].contourf(delta, norm=DivergingNorm(vmin=-100, vcenter=0, vmax=100), levels=arange(-10, 12, 2) * 10,
                         cmap='seismic')

    ax[0].set_yticks(ticks=arange(0, len(data_latitude), 6))
    ax[0].set_yticklabels(labels=data_latitude[::6])

    # Tick labels are only meaningful on fixed tick positions, so the
    # positions are set on every panel, not just the first one. The first
    # label of panels 2 and 3 is left empty because the panels touch.
    ax[0].set_xticks(ticks=ndx)
    ax[0].set_xticklabels(labels=[0, 90, 180, 270, 359])
    for panel in ax[1:]:
        panel.set_xticks(ticks=ndx)
        panel.set_xticklabels(labels=['', 90, 180, 270, 359])

    # Colorbar axes are placed by hand below the panels:
    # one spanning panels 1-2, one under panel 3.
    pos1 = ax[0].get_position()
    pos3 = ax[2].get_position()
    cbar_ax1 = fig.add_axes([pos1.x0 + 0.02, 0.05, pos3.x0 - pos1.x0 - 0.04, 0.03])
    cbar1 = fig.colorbar(pc1, cax=cbar_ax1, orientation="horizontal", format=fmt)
    cbar1.ax.set_title(unit)

    cbar_ax2 = fig.add_axes([pos3.x0 + 0.02, 0.05, pos3.x1 - pos3.x0 - 0.04, 0.03])
    cbar2 = fig.colorbar(pc2, cax=cbar_ax2, orientation="horizontal")
    cbar2.ax.set_title('%')

    fig.text(0.06, 0.5, 'Latitude (°N)', ha='center', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.15, 'Solar longitude (°)', ha='center', va='center', fontsize=14)
    plt.savefig(save_name + '.png', bbox_inches='tight')
    plt.close(fig)
コード例 #5
0
def makePlots(datasets, titles, vmin, vmax, delta, colorbar_title, split=3, sup_title=None, res=2, color='seismic'):
    """
    Draw several map panels in one figure with a single shared colorbar.

    Useful for generating climatology, trend maps, anomaly maps and
    composites. Originally written for three fields arranged in the order
    June, July, August.

    Args:
        datasets: sequence of 2D fields, one per panel.
        titles: panel titles, indexed in step with ``datasets``.
        vmin, vmax: colour limits; also the range of the colorbar ticks.
        delta: spacing between consecutive colorbar ticks.
        colorbar_title: y-label of the shared colorbar.
        split: number of subplot rows. The number of columns is chosen so that
            every dataset gets a cell.
        sup_title: optional figure super-title; when empty or None no
            super-title is drawn and panel titles sit directly on the axes.
        res: stride used to thin out the colorbar ticks.
        color: colormap name passed to ``plt.cm.get_cmap``.
    """
    plt.figure()
    if sup_title:
        plt.suptitle(sup_title, fontsize=18)
        # lift the panel titles a little when a super-title is present
        y = 1.08
    else:
        y = 1
    subplots = len(datasets)

    ticks = np.arange(vmin, vmax + delta, delta)
    colors = plt.cm.get_cmap(color, len(ticks) - 1)
    parallels = np.arange(25, 90, 10.)
    meridians = np.arange(0., 360., 30.)

    # Round the column count up: truncating left too few cells (or zero
    # columns) whenever the dataset count was not a multiple of ``split``.
    n_cols = int(np.ceil(subplots / float(split)))

    for i in range(subplots):
        plt.subplot(split, n_cols, i + 1)
        plt.title(titles[i], y=y, fontsize=17)

        m = Basemap(projection='cyl', llcrnrlat=25, urcrnrlat=90,
                    llcrnrlon=0, urcrnrlon=360, resolution='c')
        m.imshow(datasets[i], norm=DivergingNorm(0.05), origin='upper', vmin=vmin, vmax=vmax, cmap=colors)
        m.drawcoastlines()

        m.drawparallels(parallels, labels=[1, 0, 0, 0])
        m.drawmeridians(meridians, labels=[0, 0, 0, 1])
        m.drawmapboundary(fill_color='white')

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.8, top=0.9, hspace=0.3)
    # [left, bottom, width, height]
    cax = plt.axes([0.83, 0.1, 0.01, 0.8])

    # Thin the ticks by ``res``. If the thinned set does not end on the last
    # tick, use the interior ticks instead.
    t_s = ticks[0::res]
    if t_s[-1] == ticks[-1]:
        cbar = plt.colorbar(cax=cax, orientation="vertical", ticks=t_s)
    else:
        cbar = plt.colorbar(cax=cax, orientation="vertical", ticks=ticks[1:-1:res])
    cbar.ax.set_ylabel(colorbar_title, fontsize=14)

    plt.show()
コード例 #6
0
ファイル: test_colorbar.py プロジェクト: wandyzyh/matplotlib
def test_extend_colorbar_customnorm():
    # Regression check for an error seen with DivergingNorm (maybe with
    # other norms as well) when the colorbar is created with extend='both'.
    n_pts = 100
    xx, yy = np.mgrid[-3:3:complex(0, n_pts), -2:2:complex(0, n_pts)]
    bump_origin = np.exp(-xx**2 - yy**2)
    bump_shifted = np.exp(-(xx - 1)**2 - (yy - 1)**2)
    zz = (bump_origin - bump_shifted) * 2

    fig, axes = plt.subplots(2, 1)
    top_ax = axes[0]
    norm = DivergingNorm(vcenter=0., vmin=-2, vmax=1)
    pcm = top_ax.pcolormesh(xx, yy, zz, norm=norm, cmap='RdBu_r')
    cb = fig.colorbar(pcm, ax=top_ax, extend='both')

    # The colorbar axes must end up at the expected place in the figure.
    expected_extents = [0.78375, 0.536364, 0.796147, 0.9]
    np.testing.assert_allclose(cb.ax.get_position().extents,
                               expected_extents, rtol=1e-3)
コード例 #7
0
def heatmap(fitness, string):
    """
    Show a fitness grid as an annotated heatmap with a colorbar.

    NaN entries are replaced by 0 and all values are rounded to one decimal
    before plotting. Colours diverge around zero ('coolwarm'), and every cell
    is annotated with its rounded value.

    Args:
        fitness: 2D array-like of fitness values; it is copied, the caller's
            data is not modified.
        string: prefix put in front of "Fitness in feature space" in the title.
    """
    from matplotlib.colors import DivergingNorm
    fitness = np.array(fitness)
    fitness[np.isnan(fitness)] = 0
    data = np.round(fitness, 1)
    n_rows, n_cols = len(fitness), len(fitness[0])

    fig, ax = plt.subplots()
    fig.set_figheight(10)
    fig.set_figwidth(10)

    # Single image with the diverging norm. An earlier default-styled image of
    # the same data was fully hidden underneath it and has been dropped.
    hm = plt.imshow(data,
                    norm=DivergingNorm(0),
                    cmap=cm.coolwarm,
                    interpolation="nearest")

    # One tick per cell, labelled from 1. Columns run along x and rows along
    # y; the counts used to be swapped, which only worked for square grids.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels([str(j + 1) for j in range(n_cols)])
    ax.set_yticklabels([str(i + 1) for i in range(n_rows)])

    # Write each rounded value at the centre of its cell (x = column, y = row).
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j,
                    i,
                    data[i, j],
                    ha="center",
                    va="center",
                    color="w")

    ax.set_title(string + "Fitness in feature space")
    fig.tight_layout()
    plt.xlabel('feature dimension 1')
    plt.ylabel('feature dimension 2')
    plt.colorbar(hm)
    plt.show()
コード例 #8
0
def makeCompositeplots(months, sup_title, subtitles, colorbar_title, vmin, vmax, vn, rnd, color):
    """
    Draw composite maps on a 2x2 subplot grid under a common super-title.

    Args:
        months: iterable of fields, one panel each (the grid holds four).
        sup_title: figure super-title.
        subtitles: panel titles, indexed in step with ``months``.
        colorbar_title: accepted but currently not used by this function.
        vmin, vmax: colour limits; also the end points of the colorbar ticks.
        vn: number of colorbar ticks; the colormap gets ``vn - 1`` levels.
        rnd: number of decimals the tick values are rounded to.
        color: colormap name passed to ``plt.cm.get_cmap``.
    """
    tick_values = np.round(np.linspace(vmin, vmax, vn), rnd)
    cmap = plt.cm.get_cmap(color, vn - 1)

    # Positions of the graticule lines.
    parallels = np.arange(0, 90, 10.)
    meridians = np.arange(0., 360., 30.)

    plt.figure()
    plt.suptitle(sup_title, fontsize=18)

    for i, field in enumerate(months):
        # The map object is created before the subplot is selected; this order
        # is kept as is because draw_map() is not visible from here.
        mp = draw_map()
        plt.subplot(2, 2, i + 1)
        plt.title(subtitles[i], y=1.08, fontsize=17)
        mp.imshow(field, norm=DivergingNorm(0.05), vmax=vmax, vmin=vmin,
                  origin='lower', cmap=cmap)
        mp.colorbar(pad=0.6, ticks=tick_values)
        mp.drawlsmask(land_color='gray', ocean_color='white', lakes=True, alpha=0.1)
        # Parallels are drawn unlabelled, meridians are labelled on all four sides.
        mp.drawparallels(parallels, labels=[0, 0, 0, 0])
        mp.drawmeridians(meridians, labels=[1, 1, 1, 1])

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.8, top=0.85, hspace=0.3)
    # [left, bottom, width, height]
    # NOTE(review): this extra axes is still added to the figure although no
    # shared colorbar is drawn into it - confirm whether that is intended.
    plt.axes([0.83, 0.1, 0.01, 0.75])
    plt.show()
コード例 #9
0
def horizontal_map(variable_name, date, start_hour, 
                                  end_hour, pressure_level=False, 
                                  subset=False, initiation=False, 
                                  save=False, gif=False):
    
    '''This function plots the chosen variable for the analysis 
    of the initiation environment on a horizontal (2D) map. Supported variables for plotting 
    procedure are updraft, reflectivity, helicity, pw, cape, cin, ctt, temperature_surface, 
    wind_shear, updraft_reflectivity, rh, omega, pvo, avo, theta_e, water_vapor, uv_wind and 
    divergence.'''
    
    ### Predefine some variables ###
    
    # Get the list of all needed wrf files
    data_dir = '/scratch3/thomasl/work/data/casestudy_baden/'
    
    # Define save directory
    save_dir = '/scratch3/thomasl/work/retrospective_part'                '/casestudy_baden/horizontal_maps/'

    # Change extent of plot
    subset_extent = [6.2, 9.4, 46.5, 48.5]
    
    # Set the location of the initiation of the thunderstorm
    initiation_location = CoordPair(lat=47.25, lon=7.85)

    # 2D variables:
    if variable_name == 'updraft':
        variable_name = 'W_UP_MAX'
        title_name = 'Maximum Z-Wind Updraft'
        colorbar_label = 'Max Z-Wind Updraft [$m$ $s^-$$^1$]'
        save_name = 'updraft'
        variable_min = 0
        variable_max = 30
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
        
    elif variable_name == 'reflectivity':
        variable_name = 'REFD_MAX'
        title_name = 'Maximum Derived Radar Reflectivity'
        colorbar_label = 'Maximum Derived Radar Reflectivity [$dBZ$]'
        save_name = 'reflectivity'
        variable_min = 0
        variable_max = 75
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
        
    elif variable_name == 'helicity':
        variable_name = 'UP_HELI_MAX'
        title_name = 'Maximum Updraft Helicity'
        colorbar_label = 'Maximum Updraft Helicity [$m^{2}$ $s^{-2}$]'
        save_name = 'helicity'
        variable_min = 0 
        variable_max = 140
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
        
    elif variable_name == 'pw':
        title_name = 'Precipitable Water'
        colorbar_label = 'Precipitable Water [$kg$ $m^{-2}$]'
        save_name = 'pw'
        variable_min = 0 
        variable_max = 50 
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
    
    elif variable_name == 'cape':
        variable_name = 'cape_2d'
        title_name = 'CAPE'
        colorbar_label = 'Convective Available Potential Energy'                             '[$J$ $kg^{-1}$]'
        save_name = 'cape'
        variable_min = 0 
        variable_max = 3000 
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
        
    elif variable_name == 'cin':
        variable_name = 'cape_2d'
        title_name = 'CIN'
        colorbar_label = 'Convective Inhibition [$J$ $kg^{-1}$]'
        save_name = 'cin'
        variable_min = 0
        variable_max = 100 

        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
        
    elif variable_name == 'ctt':
        title_name = 'Cloud Top Temperature'
        colorbar_label = 'Cloud Top Temperature [$K$]'
        save_name = 'cct'
        variable_min = 210 
        variable_max = 300 
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
    
    elif variable_name == 'temperature_surface':
        variable_name = 'T2'
        title_name = 'Temperature @ 2 m'
        colorbar_label = 'Temperature [$K$]'
        save_name = 'temperature_surface'
        variable_min = 285
        variable_max = 305

        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
            
    elif variable_name == 'wind_shear':
        variable_name = 'slp'
        title_name = 'SLP, Wind @ 850hPa, Wind @ 500hPa\n'                         'and 500-850hPa Vertical Wind Shear'
        save_name = 'wind_shear'
        variable_min = 1000
        variable_max = 1020

        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
            
    elif variable_name == 'updraft_reflectivity':
        variable_name = 'W_UP_MAX'
        title_name = 'Updraft and Reflectivity'
        colorbar_label = 'Max Z-Wind Updraft [$m$ $s^-$$^1$]'
        save_name = 'updraft_reflectivity'
        variable_min = 0
        variable_max = 30
        
        # Check if a certain pressure_level was defined.
        if pressure_level != False: 
            sys.exit('The variable {} is a 2D variable. '                      'Definition of a pressure_level for '                      'plotting process is not required.'.format(variable_name))
            
    # 3D variables:
    elif variable_name == 'rh':
        title_name = 'Relative Humidity'
        colorbar_label = 'Relative Humidity [$pct$]'
        save_name = 'rh'
        variable_min = 0
        variable_max = 100
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
        
    elif variable_name == 'omega':
        title_name = 'Vertical Motion'
        colorbar_label = 'Omega [$Pa$ $s^-$$^1$]'
        save_name = 'omega'
        variable_min = -50
        variable_max = 50
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
            
    elif variable_name == 'pvo':
        title_name = 'Potential Vorticity'
        colorbar_label = 'Potential Vorticity [$PVU$]'
        save_name = 'pvo'
        variable_min = -1 
        variable_max = 9 
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
            
    elif variable_name == 'avo':
        title_name = 'Absolute Vorticity'
        colorbar_label = 'Absolute Vorticity [$10^{-5}$'                             '$s^{-1}$]'
        save_name = 'avo'
        variable_min = -250
        variable_max = 250 
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
    
    elif variable_name == 'theta_e':
        title_name = 'Theta-E'
        colorbar_label = 'Theta-E [$K$]'
        save_name = 'theta_e'
        variable_min = 315
        variable_max = 335 
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
            
    elif variable_name == 'water_vapor':
        variable_name = 'QVAPOR'
        title_name = 'Water Vapor Mixing Ratio'
        colorbar_label = 'Water Vapor Mixing Ratio [$g$ $kg^{-1}$]'
        save_name = 'water_vapor'
        variable_min = 5
        variable_max = 15
        
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
    
    elif variable_name == 'uv_wind':
        variable_name = 'wspd_wdir'
        title_name = 'Wind Speed and Direction'
        colorbar_label = 'Wind Speed [$m$ $s^{-1}$]'
        save_name = 'uv_wind'
        variable_min = 0
        variable_max = 10 

        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
        
    elif variable_name == 'divergence':
        variable_name = 'ua'
        title_name = 'Horizontal Wind Divergence'
        colorbar_label = 'Divergence [$10^{-6}$ $s^{-1}$]'
        save_name = 'divergence'
        variable_min = -2.5
        variable_max = 2.5
            
        # Check if a certain pressure_level was defined.
        if pressure_level == False: 
            sys.exit('The variable {} is a 3D variable. '                      'Definition of a pressure_level for '                      'plotting process is required.'.format(variable_name))
    
    # Make a list of all wrf files in data directory
    wrflist = list()
    for (dirpath, dirnames, filenames) in os.walk(data_dir):
        wrflist += [os.path.join(dirpath, file) for file in filenames]
    
    ### Plotting Iteration ###
    
    # Iterate over a list of hourly timesteps
    time = list()
    for i in range(start_hour, end_hour):
        time = str(i).zfill(2)

        # Iterate over all 5 minutes steps of hour
        for j in range(0, 60, 5):
            minutes = str(j).zfill(2)
                
            # Load the netCDF files out of the wrflist
            ncfile = [Dataset(x) for x in wrflist
                if x.endswith('{}_{}:{}:00'.format(date, time, minutes))]
            
            # Load variable(s)
            if title_name == 'CAPE':
                variable = getvar(ncfile, variable_name)[0,:]
                
            elif title_name == 'CIN':
                variable = getvar(ncfile, variable_name)[1,:]
                
            elif variable_name == 'ctt':
                variable = getvar(ncfile, variable_name, units='K')
                
            elif variable_name == 'wspd_wdir':
                variable = getvar(ncfile, variable_name)[0,:]
            
            elif variable_name == 'QVAPOR':
                variable = getvar(ncfile, variable_name)*1000 # convert to g/kg
                    
            else:
                variable = getvar(ncfile, variable_name)

            if variable_name == 'slp':
                slp = variable.squeeze()
                
                ua = getvar(ncfile, 'ua')
                va = getvar(ncfile, 'va')

                p = getvar(ncfile, 'pressure')

                u_wind850 = interplevel(ua, p, 850)
                v_wind850 = interplevel(va, p, 850)

                u_wind850 = u_wind850.squeeze()
                v_wind850 = v_wind850.squeeze()

                u_wind500 = interplevel(ua, p, 500)
                v_wind500 = interplevel(va, p, 500)

                u_wind500 = u_wind500.squeeze()
                v_wind500 = v_wind500.squeeze()

                slp = ndimage.gaussian_filter(slp, sigma=3, order=0)
            
            # Interpolating 3d data to a horizontal pressure level
            if pressure_level != False:
                p = getvar(ncfile, 'pressure')
                variable_pressure = interplevel(variable, p, 
                                                pressure_level)
                variable = variable_pressure
                
            if variable_name == 'wspd_wdir':
                ua = getvar(ncfile, 'ua')
                va = getvar(ncfile, 'va')
                u_pressure = interplevel(ua, p, pressure_level)
                v_pressure = interplevel(va, p, pressure_level)
                
            elif title_name == 'Updraft and Reflectivity':
                reflectivity = getvar(ncfile, 'REFD_MAX')
                
            elif title_name == 'Difference in Theta-E values':
                variable = getvar(ncfile, variable_name)
                
                p = getvar(ncfile, 'pressure')
                variable_pressure1 = interplevel(variable, p, '950')
                variable_pressure2 = interplevel(variable, p, '950')
                
            elif variable_name == 'ua':
                va = getvar(ncfile, 'va')

                p = getvar(ncfile, 'pressure')

                v_pressure = interplevel(va, p, pressure_level)

                u_wind = variable.squeeze()
                v_wind = v_pressure.squeeze()

                u_wind.attrs['units']='meters/second'
                v_wind.attrs['units']='meters/second'
                
                lats, lons = latlon_coords(variable)
                lats = lats.squeeze()
                lons = lons.squeeze()

                dx, dy = mpcalc.lat_lon_grid_deltas(to_np(lons), to_np(lats))

                divergence = mpcalc.divergence(u_wind, v_wind, dx, dy, dim_order='yx')
                divergence = divergence*1e3


            # Define cart projection
            lats, lons = latlon_coords(variable)
            cart_proj = ccrs.LambertConformal(central_longitude=8.722206, 
                                    central_latitude=46.73585)

            bounds = geo_bounds(wrfin=ncfile)

            # Create figure
            fig = plt.figure(figsize=(15, 10))

            if variable_name == 'slp':
                fig.patch.set_facecolor('k')

            ax = plt.axes(projection=cart_proj)

            ### Set map extent ###
            domain_extent = [3.701088, 13.814863, 43.85472,49.49499]

            if subset == True:
                ax.set_extent([subset_extent[0],subset_extent[1],
                               subset_extent[2],subset_extent[3]],
                                 ccrs.PlateCarree())
                
            else: 
                ax.set_extent([domain_extent[0]+0.7,domain_extent[1]-0.7,
                               domain_extent[2]+0.1,domain_extent[3]-0.1],
                                 ccrs.PlateCarree())

            # Plot contour of variables
            levels_num = 11
            levels = np.linspace(variable_min, variable_max, levels_num)
            
            # Creating new colormap for diverging colormaps
            def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
                    new_cmap = LinearSegmentedColormap.from_list(
                        'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, 
                                                            b=maxval),
                        cmap(np.linspace(minval, maxval, n)))
                    return new_cmap
            
            cmap = plt.get_cmap('RdYlBu')
            
            if title_name == 'CIN':
                cmap = ListedColormap(sns.cubehelix_palette(levels_num-1, 
                                        start=.5, rot=-.75, reverse=True))
                variable_plot = plt.contourf(to_np(lons), to_np(lats), to_np(variable), 
                                 levels=levels, transform=ccrs.PlateCarree(), extend='max', 
                                 cmap=cmap)
                initiation_color = 'r*'
                
            elif variable_name == 'ctt':
                cmap = ListedColormap(sns.cubehelix_palette(levels_num-1, 
                                        start=.5, rot=-.75, reverse=True))
                variable_plot = plt.contourf(to_np(lons), to_np(lats), to_np(variable), 
                                 levels=levels, transform=ccrs.PlateCarree(), extend='both', 
                                 cmap=cmap)
                initiation_color = 'r*'
                
            elif variable_name == 'pvo':
                cmap = plt.get_cmap('RdYlBu_r')
                new_cmap = truncate_colormap(cmap, 0.05, 0.9)
                new_norm = DivergingNorm(vmin=-1., vcenter=2., vmax=10)
                
                variable_plot = plt.contourf(to_np(lons), to_np(lats), to_np(variable), 
                                 levels=levels, transform=ccrs.PlateCarree(), 
                                 cmap=new_cmap, extend='both', norm=new_norm)
                initiation_color = 'k*'
                
            elif variable_name == 'avo':
                cmap = plt.get_cmap('RdYlBu_r')
                new_cmap = truncate_colormap(cmap, 0.05, 0.9)
                new_norm = DivergingNorm(vmin=variable_min, vcenter=0, vmax=variable_max)
                
                variable_plot = plt.contourf(to_np(lons), to_np(lats), to_np(variable), 
                                 levels=levels, transform=ccrs.PlateCarree(), 
                                 cmap=new_cmap, extend='both', norm=new_norm)
                initiation_color = 'k*'
                
            elif variable_name == 'omega':
                new_cmap = truncate_colormap(cmap, 0.05, 0.9)
                new_norm = DivergingNorm(vmin=variable_min, vcenter=0, vmax=variable_max)

                variable_plot = plt.contourf(to_np(lons), to_np(lats), to_np(variable), 
                                 levels=levels, transform=ccrs.PlateCarree(), 
                                 cmap=new_cmap, extend='both', norm=new_norm)
                initiation_color = 'k*'
                
            elif variable_name == 'ua':
                new_cmap = truncate_colormap(cmap, 0.05, 0.9)
                new_norm = DivergingNorm(vmin=variable_min, vcenter=0, vmax=variable_max)

                variable_plot = plt.contourf(to_np(lons), to_np(lats), divergence, 
                                 levels=levels, transform=ccrs.PlateCarree(), 
                                 cmap=new_cmap, extend='both', norm=new_norm)
                initiation_color = 'k*'
                
            elif variable_name == 'UP_HELI_MAX' or variable_name == 'W_UP_MAX' or variable_name == 'QVAPOR':
                cmap = ListedColormap(sns.cubehelix_palette(levels_num-1, 
                                        start=.5, rot=-.75))
                variable_plot = plt.contourf(to_np(lons), to_np(lats), 
                                to_np(variable), levels=levels, extend='max',
                                transform=ccrs.PlateCarree(),cmap=cmap)
                initiation_color = 'r*'
                
            elif variable_name == 'theta_e' or variable_name == 't2':
                cmap = ListedColormap(sns.cubehelix_palette(levels_num-1, 
                                        start=.5, rot=-.75))
                variable_plot = plt.contourf(to_np(lons), to_np(lats), 
                                to_np(variable), levels=levels, extend='both',
                                transform=ccrs.PlateCarree(),cmap=cmap)
                initiation_color = 'r*'

                
            elif variable_name == 'REFD_MAX':
                levels = np.arange(5., 75., 5.)
                dbz_rgb = np.array([[4,233,231],
                                    [1,159,244], [3,0,244],
                                    [2,253,2], [1,197,1],
                                    [0,142,0], [253,248,2],
                                    [229,188,0], [253,149,0],
                                    [253,0,0], [212,0,0],
                                    [188,0,0],[248,0,253],
                                    [152,84,198]], np.float32) / 255.0
                dbz_cmap, dbz_norm = from_levels_and_colors(levels, dbz_rgb,
                                                           extend='max')
                
                variable_plot = plt.contourf(to_np(lons), to_np(lats), 
                                 to_np(variable), levels=levels, extend='max',
                                 transform=ccrs.PlateCarree(), cmap=dbz_cmap,
                                            norm=dbz_norm)
                initiation_color = 'r*'
                
            elif variable_name == 'slp':
                ax.background_patch.set_fill(False)
                    
                wslice = slice(1, None, 12)
                # Plot 850-hPa wind vectors
                vectors850 = ax.quiver(to_np(lons)[wslice, wslice], 
                                       to_np(lats)[wslice, wslice],
                                       to_np(u_wind850)[wslice, wslice], 
                                       to_np(v_wind850)[wslice, wslice],
                                       headlength=4, headwidth=3, scale=400, color='gold', 
                                       label='850mb wind', transform=ccrs.PlateCarree(), 
                                       zorder=2)

                # Plot 500-hPa wind vectors
                vectors500 = ax.quiver(to_np(lons)[wslice, wslice], 
                                       to_np(lats)[wslice, wslice],
                                       to_np(u_wind500)[wslice, wslice], 
                                       to_np(v_wind500)[wslice, wslice],
                                       headlength=4, headwidth=3, scale=400, 
                                       color='cornflowerblue', zorder=2,
                                       label='500mb wind', transform=ccrs.PlateCarree())

                # Plot 500-850 shear
                shear = ax.quiver(to_np(lons[wslice, wslice]), 
                                  to_np(lats[wslice, wslice]),
                                  to_np(u_wind500[wslice, wslice]) - 
                                  to_np(u_wind850[wslice, wslice]),
                                  to_np(v_wind500[wslice, wslice]) - 
                                  to_np(v_wind850[wslice, wslice]),
                                  headlength=4, headwidth=3, scale=400, 
                                  color='deeppink', zorder=2,
                                  label='500-850mb shear', transform=ccrs.PlateCarree())

                contour = ax.contour(to_np(lons), to_np(lats), slp, levels=levels, 
                                     colors='lime', linewidths=2, alpha=0.5, zorder=1,
                                     transform=ccrs.PlateCarree())
                ax.clabel(contour, fontsize=12, inline=1, inline_spacing=4, fmt='%i')
                
                # Add a legend
                ax.legend(('850mb wind', '500mb wind', '500-850mb shear'), loc=4)

                # Manually set colors for legend
                legend = ax.get_legend()
                legend.legendHandles[0].set_color('gold')
                legend.legendHandles[1].set_color('cornflowerblue')
                legend.legendHandles[2].set_color('deeppink')
                
                initiation_color = 'w*'
            
            else:
                cmap = ListedColormap(sns.cubehelix_palette(10, 
                                        start=.5, rot=-.75))
                variable_plot = plt.contourf(to_np(lons), to_np(lats), 
                                to_np(variable), levels=levels,
                                transform=ccrs.PlateCarree(),cmap=cmap)
                initiation_color = 'r*'
                         
            # Plot reflectivity contours with colorbar 
            if title_name == 'Updraft and Reflectivity':
                dbz_levels = np.arange(35., 75., 5.)
                dbz_rgb = np.array([[253,248,2],
                        [229,188,0], [253,149,0],
                        [253,0,0], [212,0,0],
                        [188,0,0],[248,0,253],
                        [152,84,198]], np.float32) / 255.0
                dbz_cmap, dbz_norm = from_levels_and_colors(dbz_levels, dbz_rgb,
                                               extend='max')                

                contours = plt.contour(to_np(lons), to_np(lats), 
                                           to_np(reflectivity), 
                                           levels=dbz_levels, 
                                           transform=ccrs.PlateCarree(), 
                                           cmap=dbz_cmap, norm=dbz_norm, 
                                           linewidths=1)

                cbar_refl = mpu.colorbar(contours, ax, orientation='horizontal', aspect=10, 
                                         shrink=.5, pad=0.05)
                cbar_refl.set_label('Maximum Derived Radar Reflectivity'                                         '[$dBZ$]', fontsize=12.5)
                colorbar_lines = cbar_refl.ax.get_children()
                colorbar_lines[0].set_linewidths([10]*5)
            
            # Add wind quivers for every 10th data point
            if variable_name == 'wspd_wdir':
                plt.quiver(to_np(lons[::10,::10]), to_np(lats[::10,::10]),
                            to_np(u_pressure[::10, ::10]), 
                            to_np(v_pressure[::10, ::10]),
                            transform=ccrs.PlateCarree())
            
            # Plot colorbar
            if variable_name == 'slp':
                pass
            else:
                cbar = mpu.colorbar(variable_plot, ax, orientation='vertical', aspect=40, 
                                    shrink=.05, pad=0.05)
                cbar.set_label(colorbar_label, fontsize=15)
                cbar.set_ticks(levels)
            
            # Add borders and coastlines
            if variable_name == 'slp':
                ax.add_feature(cfeature.BORDERS.with_scale('10m'), 
                           edgecolor='white', linewidth=2)
                ax.add_feature(cfeature.COASTLINE.with_scale('10m'), 
                           edgecolor='white', linewidth=2)
            else:
                ax.add_feature(cfeature.BORDERS.with_scale('10m'), 
                               linewidth=0.8)
                ax.add_feature(cfeature.COASTLINE.with_scale('10m'), 
                               linewidth=0.8)
            
            ### Add initiation location ###
            if initiation == True:
                ax.plot(initiation_location.lon, initiation_location.lat, 
                        initiation_color, markersize=20, transform=ccrs.PlateCarree())
            
            # Add gridlines
            lon = np.arange(0, 20, 1)
            lat = np.arange(40, 60, 1)

            gl = ax.gridlines(xlocs=lon, ylocs=lat, zorder=3)
            
            # Add tick labels
            mpu.yticklabels(lat, ax=ax, fontsize=12.5)
            mpu.xticklabels(lon, ax=ax, fontsize=12.5)
            
            # Make nicetime
            file_name = '{}wrfout_d02_{}_{}:{}:00'.format(data_dir, 
                                                          date, time, minutes)
            xr_file = xr.open_dataset(file_name)
            nicetime = pd.to_datetime(xr_file.QVAPOR.isel(Time=0).XTIME.values)
            nicetime = nicetime.strftime('%Y-%m-%d %H:%M')
            
            # Add plot title
            if pressure_level != False: 
                ax.set_title('{} @ {} hPa'.format(title_name, pressure_level), 
                             loc='left', fontsize=15)
                ax.set_title('Valid time: {} UTC'.format(nicetime), 
                             loc='right', fontsize=15)
            else:
                if variable_name == 'slp':
                    ax.set_title(title_name, loc='left', fontsize=15, color='white')
                    ax.set_title('Valid time: {} UTC'.format(nicetime), 
                                 loc='right', fontsize=15, color='white')
                else:
                    ax.set_title(title_name, loc='left', fontsize=20)
                    ax.set_title('Valid time: {} UTC'.format(nicetime), 
                                 loc='right', fontsize=15)

            plt.show()
            
            ### Save figure ###
            if save == True:
                if pressure_level != False: 
                    if subset == True:
                        fig.savefig('{}/{}/horizontal_map_{}_subset_{}_{}_{}:{}.png'.format(
                            save_dir, save_name, save_name, pressure_level, date, time, 
                            minutes), bbox_inches='tight', dpi=300)
                    else: 
                        fig.savefig('{}/{}/horizontal_map_{}_{}_{}_{}:{}.png'.format(
                            save_dir, save_name, save_name, pressure_level, date, time, 
                            minutes), bbox_inches='tight', dpi=300)
                
                else: 
                    if subset == True:
                        fig.savefig('{}/{}/horizontal_map_{}_subset_{}_{}:{}.png'.format(
                            save_dir, save_name, save_name, date, time, minutes),
                                    bbox_inches='tight', dpi=300, facecolor=fig.get_facecolor())
                    
                    else: 
                        fig.savefig('{}/{}/horizontal_map_{}_{}_{}:{}.png'.format(
                            save_dir, save_name, save_name, date, time, minutes), 
                                    bbox_inches='tight', dpi=300, facecolor=fig.get_facecolor())
        
    ### Make a GIF from the plots ###
    if gif == True: 
        # Predefine some variables
        gif_data_dir = save_dir + save_name
        gif_save_dir = '{}gifs/'.format(save_dir)
        gif_save_name = 'horizontal_map_{}.gif'.format(save_name)

        # GIF creating procedure
        os.chdir(gif_data_dir)

        image_folder = os.fsencode(gif_data_dir)

        filenames = []

        for file in os.listdir(image_folder):
            filename = os.fsdecode(file)
            if filename.endswith( ('.png') ):
                filenames.append(filename)

        filenames.sort()
        images = list(map(lambda filename: imageio.imread(filename), 
                          filenames))

        imageio.mimsave(os.path.join(gif_save_dir + gif_save_name), 
                        images, duration = 0.50)
コード例 #10
0
def map_validation_colorbar(rain, depths, calibrated, validated, road, demarr,
                            slopearr, failarr, failinterval, fig_height,
                            fig_width, fig_name):
    """
    Map the difference between modelled and observed failure times.

    Draws the DEM in grey, overlays a 4x4-cell patch for every validated
    point coloured by (time_of_failure - observed_failtime) / 86400, overlays
    the calibrated points, adds a legend and a colorbar, and saves the figure
    to ``fig_name``.

    Args:
        rain, depths, slopearr, failarr, failinterval: not used by this
            function; kept so existing callers keep working.
        calibrated: table with 'col' and 'row' columns (accessed through
            ``[...]`` and iterated); one patch is drawn per row.
        validated: table with 'col', 'row', 'time_of_failure' and
            'observed_failtime' columns (accessed through ``.iloc``).
        road: line artist added to the axes with ``add_line``.
        demarr: 2D elevation array; cells <= -10 are masked out.
        fig_height, fig_width: figure size passed to ``plt.figure``.
        fig_name: path handed to ``plt.savefig``.
    """
    fig = plt.figure(1, facecolor='White', figsize=[fig_width, fig_height])
    ax1 = plt.subplot2grid((1, 1), (0, 0), colspan=1, rowspan=1)

    # Mark a 4x4 patch around every calibrated point.
    calib_arr = 0 * demarr
    for x, y in zip(calibrated['col'], calibrated['row']):
        calib_arr[y - 2:y + 2, x - 2:x + 2] = 1

    # Store the signed time difference around every validated point.
    # The divisor 24 * 3600 presumably converts seconds to days (the plot
    # title says "days") -- TODO confirm the input units.
    # NOTE(review): valid_arr inherits demarr's dtype; with an integer DEM
    # the differences would be truncated -- confirm demarr is float.
    valid_arr = 0 * demarr
    for i in range(len(validated)):
        x = int(validated['col'].iloc[i])
        y = int(validated['row'].iloc[i])
        # Skip points too close to the border for a full patch.
        if 2 <= x <= len(demarr[0]) - 2 and 2 <= y <= len(demarr) - 2:
            # The former early/on-time/late class codes (1/4/2) were always
            # overwritten by this assignment, so only the difference is kept.
            valid_arr[y - 2:y + 2, x - 2:x + 2] = (
                validated['time_of_failure'].iloc[i] -
                validated['observed_failtime'].iloc[i]) / (24 * 3600)

    dem_mask = np.ma.masked_where(demarr <= -10, demarr)
    ax1.add_line(road)
    Map1 = ax1.imshow(dem_mask,
                      interpolation='None',
                      cmap=plt.cm.Greys_r,
                      vmin=np.amin(dem_mask),
                      vmax=np.amax(dem_mask),
                      alpha=1.)

    # Cells equal to 0 are treated as "no data" and hidden. A difference of
    # exactly 0 is therefore hidden too.
    valid_mask = np.ma.masked_where(valid_arr == 0, valid_arr)
    Map2 = ax1.imshow(valid_mask,
                      interpolation='None',
                      norm=DivergingNorm(0),
                      cmap=plt.cm.jet,
                      vmin=np.amin(valid_mask),
                      vmax=np.amax(valid_mask),
                      alpha=1.)

    calib_mask = np.ma.masked_where(calib_arr == 0., calib_arr)
    Map1 = ax1.imshow(calib_mask,
                      interpolation='None',
                      cmap=plt.cm.cool,
                      vmin=0,
                      vmax=1,
                      alpha=1.)

    # Legend: a single proxy patch for the calibrated points.
    unique_values = [1.0]
    unique_values_categories = ["Calibrated"]
    unique_values_dict = OrderedDict(
        zip(unique_values_categories, unique_values))
    print(unique_values_dict)
    colors = ['fuchsia']
    patches = [
        mpatches.Patch(color=color, label="{l}".format(l=label))
        for color, label in zip(colors, unique_values_dict.keys())
    ]
    plt.legend(handles=patches,
               fontsize=12,
               bbox_to_anchor=(0, 0, 0.5, 0.5),
               loc='lower left')
    # Hide all ticks and tick labels on both axes.
    plt.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        labelbottom=False,
        labelleft=False)
    plt.title('Difference in modelled and observed failure times (days)',
              fontsize=20,
              pad=10.)

    # Build the colorbar from the image itself so it uses the same
    # zero-centred norm and colormap as the plotted differences. A separate
    # linear Normalize would not match the image colours.
    cax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
    cb = fig.colorbar(Map2, cax=cax, spacing='proportional')
    plt.savefig(fig_name)
    plt.cla()
    freqmap, class_counts = get_freqmap(f)

# Load the confusion counts of both models from their exported choice files.
with open('choices.cash.csv', 'r') as f:
    rnn_confs = get_confusions(f)

with open('cnn-baseline-strict-choices.csv', 'r') as f:
    cnn_confs = get_confusions(f)

n_classes = 334
mat = np.zeros((n_classes, n_classes), dtype=float)

# RNN confusions are added and CNN confusions subtracted, so every cell holds
# the RNN-minus-CNN count (diagonal entries included).
for weight, confs in ((1, rnn_confs), (-1, cnn_confs)):
    for lbl in confs:
        for pred in confs[lbl]:
            mat[freqmap[lbl], freqmap[pred]] += weight * confs[lbl][pred]

# Scale each label row by that class's count.
for cls, ct in class_counts.items():
    mat[freqmap[cls]] = mat[freqmap[cls]] / ct

fig, ax = plt.subplots()
# fignum=0 draws into the axes created above; the norm pins white to zero.
plt.matshow(mat, cmap='bwr', norm=DivergingNorm(0), fignum=0)
ax.set_ylabel('Labels')
ax.set_xlabel('Predictions')
ax.set_xticks(range(n_classes))
ax.set_yticks(range(n_classes))
# NOTE(review): labels follow freqmap's iteration order; this presumably
# matches the index order stored in freqmap -- confirm.
ax.set_xticklabels([str(cls) for cls in freqmap], rotation=90)
ax.set_yticklabels([str(cls) for cls in freqmap])
plt.show()
コード例 #12
0
ファイル: molecule.py プロジェクト: peterspackman/chmpy
    def atomic_stockholder_weight_isosurfaces(self, **kwargs):
        """
        Calculate the stockholder weight isosurfaces for the atoms
        in this molecule, with the provided background density.

        One surface is generated per atom; each atom's neighbours are the
        other atoms closer than `radius` (atoms closer than 1e-3 are treated
        as the atom itself and excluded).

        Args:
            kwargs (dict): keyword arguments to be passed to isosurface
                generation code

                Options are:
                ```
                background: float, optional
                    'background' density to ensure closed surfaces for isolated atoms
                    (default=1e-5)
                isovalue: float, optional
                    level set value for the isosurface (default=0.5). Must be between
                    0 and 1, but values other than 0.5 probably won't make sense anyway.
                separation: float, optional
                    separation between density grid used in the surface calculation
                    (default 0.2) in Angstroms. `resolution` is accepted as an
                    alias when `separation` is not given.
                radius: float, optional
                    maximum distance for contributing neighbours for the stockholder
                    weight calculation (default 12.0)
                color: str, optional
                    surface property to use for vertex coloring, one of ('d_norm_i',
                    'd_i', 'd_norm_e', 'd_e', 'd_norm') (default 'd_norm_i')
                colormap: str, optional
                    matplotlib colormap to use for surface coloring (default: the
                    entry for `color` in DEFAULT_COLORMAPS, else 'viridis_r')
                midpoint: float, optional, default 0.0 if using d_norm
                    use the midpoint norm (as is used in CrystalExplorer)
                ```

        Returns:
            List[trimesh.Trimesh]: A list of meshes representing the stockholder weight isosurfaces
        """

        from chmpy import StockholderWeight
        from chmpy.surface import stockholder_weight_isosurface
        from matplotlib.cm import get_cmap
        import trimesh
        from chmpy.util.color import DEFAULT_COLORMAPS

        sep = kwargs.get("separation", kwargs.get("resolution", 0.2))
        radius = kwargs.get("radius", 12.0)
        background = kwargs.get("background", 1e-5)
        vertex_color = kwargs.get("color", "d_norm_i")
        isovalue = kwargs.get("isovalue", 0.5)
        midpoint = kwargs.get("midpoint",
                              0.0 if vertex_color == "d_norm" else None)
        colormap = get_cmap(
            kwargs.get("colormap",
                       DEFAULT_COLORMAPS.get(vertex_color, "viridis_r")))

        midpoint_norm_cls = None
        if midpoint is not None:
            # DivergingNorm was renamed TwoSlopeNorm in matplotlib 3.2 and
            # later removed; prefer the new name and fall back for 3.1.
            try:
                from matplotlib.colors import TwoSlopeNorm as midpoint_norm_cls
            except ImportError:
                from matplotlib.colors import (
                    DivergingNorm as midpoint_norm_cls)

        elements = self.atomic_numbers
        positions = self.positions
        dists = self.distance_matrix

        # First pass: one isosurface per atom.
        isos = []
        for n in range(elements.shape[0]):
            els = elements[n:n + 1]
            pos = positions[n:n + 1, :]
            idxs = np.where((dists[n, :] < radius) & (dists[n, :] > 1e-3))[0]
            neighbour_els = elements[idxs]
            neighbour_pos = positions[idxs]

            s = StockholderWeight.from_arrays(els,
                                              pos,
                                              neighbour_els,
                                              neighbour_pos,
                                              background=background)
            isos.append(
                stockholder_weight_isosurface(s, isovalue=isovalue, sep=sep))

        # Second pass: colour the vertices and build the meshes.
        meshes = []
        for iso in isos:
            prop = iso.vertex_prop[vertex_color]
            if midpoint_norm_cls is not None:
                # Map the property to [0, 1] with `midpoint` at 0.5.
                norm = midpoint_norm_cls(vmin=prop.min(),
                                         vcenter=midpoint,
                                         vmax=prop.max())
                prop = norm(prop)
            # NOTE(review): without a midpoint the raw property values are
            # passed to the colormap unnormalised -- confirm this is intended.
            color = colormap(prop)
            meshes.append(
                trimesh.Trimesh(
                    vertices=iso.vertices,
                    faces=iso.faces,
                    normals=iso.normals,
                    vertex_colors=color,
                ))
        return meshes
コード例 #13
0
def _parse_color_segments(segments, name, hinge=0, colormodel='RGB', N=256):
    """
    A private function to parse color segments.

    Parameters
    ----------
    segments : list
        A list of segments following the GMT structure:

        `z0 color0  z1 color1`

        Where color is either a named color from the GMT color list like
        `black` or `r g b` or `r/g/b`.

    name : str, optional
        name of the returned cmap.

    hinge : float, optional
        Zero by default. Only used when it lies strictly between the first
        and the last z value; otherwise the colormap is treated as
        sequential.

    colormodel : str, optional
        Assumed to be ``'RGB'`` by default. ``'HSV'`` is also accepted.

    N : int, optional
        Number of entries in the look-up-table of the colormap.

    Returns
    -------
    cmap : Colormap
        Either a LinearSegmentedColormap if sequential or
        DynamicColormap if diverging around a ``hinge`` value.

    Raises
    ------
    ValueError
        If ``segments`` is empty or ``colormodel`` is not understood.
    """
    x = []
    r = []
    g = []
    b = []
    fields = None
    xi = None
    for segment in segments:
        # parse the left side of each segment
        fields = re.split(r'\s+|[/]', segment)
        x.append(float(fields[0]))
        try:
            # Convert all three channels before appending anything so a
            # failed conversion cannot leave the lists with unequal lengths.
            rgb = (float(fields[1]), float(fields[2]), float(fields[3]))
            xi = 4
        except ValueError:
            # not numeric: look the color up by its GMT name
            r_, g_, b_ = GMT_COLOR_NAMES[fields[1]]
            rgb = (float(r_), float(g_), float(b_))
            xi = 2
        r.append(rgb[0])
        g.append(rgb[1])
        b.append(rgb[2])

    if fields is None:
        raise ValueError('No color segments to parse')

    # parse the right side of the last segment
    x.append(float(fields[xi]))

    try:
        rgb = (float(fields[xi + 1]),
               float(fields[xi + 2]),
               float(fields[xi + 3]))
    except ValueError:
        r_, g_, b_ = GMT_COLOR_NAMES[fields[-1]]
        rgb = (float(r_), float(g_), float(b_))
    r.append(rgb[0])
    g.append(rgb[1])
    b.append(rgb[2])

    x = np.array(x)
    r = np.array(r)
    g = np.array(g)
    b = np.array(b)

    if colormodel == "HSV":
        for i in range(r.shape[0]):
            # convert HSV to RGB (hue is given in degrees)
            rr, gg, bb = hsv_to_rgb(r[i] / 360., g[i], b[i])
            r[i] = rr
            g[i] = gg
            b[i] = bb
    elif colormodel == "RGB":
        r /= 255.
        g /= 255.
        b /= 255.
    else:
        raise ValueError('Color model `{}` not understood'.format(colormodel))

    if hinge is not None and x[0] < hinge < x[-1]:
        cmap_type = 'dynamic'
        norm = DivergingNorm(vmin=x[0], vcenter=hinge, vmax=x[-1])
        # index of the z value closest to the hinge
        hinge_index = np.abs(x - hinge).argmin()
    else:
        cmap_type = 'normal'
        hinge = None
        hinge_index = None
        norm = Normalize(vmin=x[0], vmax=x[-1])

    xNorm = norm(x)
    red = []
    blue = []
    green = []
    for i in range(xNorm.size):
        # avoid interpolation across the hinge: insert an extra node there
        # whose left-hand value is the previous color
        if hinge_index is not None and i == hinge_index:
            red.append([xNorm[i], r[i - 1], r[i]])
            green.append([xNorm[i], g[i - 1], g[i]])
            blue.append([xNorm[i], b[i - 1], b[i]])

        red.append([xNorm[i], r[i], r[i]])
        green.append([xNorm[i], g[i], g[i]])
        blue.append([xNorm[i], b[i], b[i]])

    # return colormap
    cdict = dict(red=red, green=green, blue=blue)
    cmap = LinearSegmentedColormap(name=name, segmentdata=cdict, N=N)
    cmap.values = x
    cmap.colors = list(zip(r, g, b))
    cmap.hinge = hinge
    cmap._init()

    if cmap_type == 'dynamic':
        return DynamicColormap(cmap)
    else:
        return cmap
コード例 #14
0
 def norm(self):
     """Return a DivergingNorm spanning vmin..vmax, centred on the hinge."""
     limits = dict(vmin=self.vmin, vcenter=self.hinge, vmax=self.vmax)
     return DivergingNorm(**limits)
ax2 = axes[1]
ax3 = axes[2]
ax4 = axes[3]
ax5 = axes[4]
ax6 = axes[5]

#########
# PLOTS #
#########

# SCATTER PLOT
# Colour each destination by its efficiency; the norm is centred on the
# mean efficiency.
scatter_plot = ax1.scatter(df["dst_lng"].values,
                           df["dst_lat"].values,
                           c=df["frac_c_efficiency"].values,
                           cmap="inferno",
                           norm=DivergingNorm(df["frac_c_efficiency"].mean()),
                           s=3)
fig.colorbar(scatter_plot, ax=ax1, orientation="horizontal",
             pad=0).minorticks_on()
ax1.set_title("Scatter plot")

# LINEAR INTERPOLATION HEATMAP
# MAP_RES grid points per unit of longitude/latitude. np.linspace requires
# an integer sample count, so the float extent * MAP_RES is truncated.
MAP_RES = 50
xx = np.linspace(df["dst_lng"].min(), df["dst_lng"].max(),
                 int((df["dst_lng"].max() - df["dst_lng"].min()) * MAP_RES))
yy = np.linspace(df["dst_lat"].min(), df["dst_lat"].max(),
                 int((df["dst_lat"].max() - df["dst_lat"].min()) * MAP_RES))
x, y = np.meshgrid(xx, yy)
z = griddata((df["dst_lng"], df["dst_lat"]),
             df["frac_c_efficiency"], (x, y),
             method="linear",
コード例 #16
0
    full['VminI'] = full['m_f555w'] - full['m_f775u']
    full['VminI_dered'] = full['m_f555w_dered'] - full['m_f775u_dered']    
    
    gsesh = qglue(full=full)
    '''

    ########################################################################
    # Visualize results
    ########################################################################

    #####################
    # Compare new SVM results to Ksoll2018 results
    # need to have run rev_30dor_dered.py to have pcolor0_dered, pmag0_dered
    # of ksoll pms data - TO DO: fix this to work without having done that

    norm = DivergingNorm(vmin=0, vcenter=0.5, vmax=1)  #lsrk
    '''
    fig,[ax2,ax1] = plt.subplots(1,2,figsize=(12,8),sharex=True,sharey=True)
    
    s1=ax1.scatter(X_full[:, 0]-X_full[:, 1], X_full[:, 0], c=full_prob[:,1], 
                cmap='RdYlBu_r',s=0.1,norm=norm)  
    ax1.set_title('New SVM')
    ax1.set_xlabel('F555W - F775W [mag]')
    ax1.set_xlim(-1,3.5)
    ax1.set_ylim(28,12)
    ax1.set_ylabel('F555W [mag]')
    fig.colorbar(s1,ax=ax1,label='P(PMS)')
    
    #ax2.scatter(X_full[:, 0]-X_full[:, 1], X_full[:, 0], c='#333399',s=0.1)  
    #s2 = ax2.scatter(pcolor0_dered,pmag0_dered,s=0.1,c=pms['p_svm'],
    #                 cmap='RdYlBu_r',norm=norm)
コード例 #17
0
# Board labels: ranks from top (8) to bottom (1), files from A to H.
row = [str(rank) for rank in range(8, 0, -1)]
collumn = list("ABCDEFGH")

# One value per square, laid out like the board (first line is rank 8).
vals = np.array([
    [-50, -40, -30, -30, -30, -30, -40, -50],
    [-40, -20, 0, 5, 5, 0, -20, -40],
    [-30, 5, 10, 15, 15, 10, 5, -30],
    [-30, 0, 15, 20, 20, 15, 0, -30],
    [-30, 5, 15, 20, 20, 15, 5, -30],
    [-30, 0, 10, 15, 15, 10, 0, -30],
    [-40, -20, 0, 0, 0, 0, -20, -40],
    [-50, -40, -30, -30, -30, -30, -40, -50],
])

fig, ax = plt.subplots()
# Centre the colormap on 0 although the value range (-50..20) is asymmetric.
zero_centred = DivergingNorm(vcenter=0, vmin=-50, vmax=20)
im = ax.imshow(vals, cmap="RdYlBu", norm=zero_centred)
fig.colorbar(im)

# Show every tick and label it with the board coordinates.
ax.set_xticks(np.arange(len(row)))
ax.set_yticks(np.arange(len(collumn)))
ax.set_xticklabels(collumn)
ax.set_yticklabels(row)

plt.setp(ax.get_xticklabels(), ha="right")

# Write each square's value at its centre.
for (i, j), value in np.ndenumerate(vals):
    text = ax.text(j, i, value, ha="center", va="center", color="w")
コード例 #18
0
# Finish and display the previous figure.
plt.legend()
plt.tight_layout()
#plt.savefig('../run_on_device/benchmark/figures/regime_2_performance_distribution.pdf')
plt.show()

# Bring both result vectors into a 672 x 672 image layout.
grid_shape = (672, 672)
step = step.reshape(grid_shape, order='A')
rand_step = rand_step.reshape(grid_shape, order='A')

matplotlib.rcParams['font.size'] = 18

my_cmap = cm.bwr  # .reversed()
n = 25
extent = [0, 1, 0, 1]

# Overlay the two images
# NOTE(review): both layers draw rand_step; `step` is reshaped above but not
# plotted here -- confirm that is intended.
fig, ax = plt.subplots()
im = ax.imshow(rand_step, cmap='bwr', extent=extent, norm=DivergingNorm(n))
# Reuse the colour limits of the first layer for the second one.
clim = im.get_clim()
im2 = ax.imshow(rand_step,
                cmap=my_cmap,
                interpolation='none',
                norm=DivergingNorm(n),
                clim=clim,
                extent=extent)

# Colorbar in its own axes to the right of the image.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
clb = fig.colorbar(im, ax=ax, cax=cax)
clb.set_label('N', labelpad=-25, y=1.09, rotation=0)

ax.set_xlabel('v5 (mV)', labelpad=-17)
ax.set_ylabel('v9 (mV)', labelpad=-75)
xticks = ['v5_max', 'v5_min']
yticks = ['v9_min', 'v9_max']
コード例 #19
0
ファイル: plot.py プロジェクト: kistoday/minerva
def plot_4_flat(datas, d_list, n_list, fig):
    """
    Draw one heat map per dataset showing the average number of liars.

    Each heat map has the number of signatures (N) on the x axis and the
    matrix dimension (D) on the y axis. All maps share one colour scale and
    a single colorbar placed at the right edge of the figure.

    Args:
        datas: iterable of ``(name, data)`` pairs, where ``data`` yields
            ``(d, n, runs)`` tuples and every run has a ``liars`` attribute.
            The subplot title is the part of ``name`` before the first "_".
        d_list: list of the D values, in row order of the heat map.
        n_list: ignored; the N axis is fixed to 500..7000 in steps of 100
            (kept in the signature for existing callers).
        fig: matplotlib figure to draw into.
    """
    def map2liars(data):
        """Return parallel lists (N, D, mean liars) for one dataset."""
        N = []
        D = []
        L = []
        for d, n, runs in data:
            mean_liars = sum(run.liars for run in runs)
            if runs:
                mean_liars /= len(runs)
            N.append(n)
            D.append(d)
            L.append(mean_liars)
        return N, D, L

    gs = get_gridspec(datas)

    # Evaluate every dataset once; the results feed both the shared colour
    # scale and the individual plots.
    liar_maps = [(name, map2liars(data)) for name, data in datas]

    max_l = 0
    for _, (_, _, z) in liar_maps:
        max_l = max(max_l, max(z))
    print(max_l)

    # NOTE(review): the centre of the colour scale is hard-coded to 6, so
    # max_l must exceed 6 for the norm to be valid -- confirm.
    divnorm = DivergingNorm(vmin=0, vcenter=6, vmax=max_l)
    # max_l is a mean and may be a float; the LUT size must be an integer.
    cmap = cm.get_cmap("RdBu", int(max_l / 2))

    # The N axis is fixed; the n_list argument is not used.
    n_list = list(range(500, 7100, 100))

    for i, (name, (x, y, z)) in enumerate(liar_maps):
        ax = fig.add_subplot(gs[i])
        # NaN marks (D, N) cells without data, so they render blank rather
        # than showing whatever uninitialised memory happened to contain.
        Z = np.full((len(d_list), len(n_list)), np.nan)
        for n, d, s in zip(x, y, z):
            if n < 8000:
                Z[d_list.index(d), n_list.index(n)] = s

        im = ax.imshow(Z,
                       aspect="auto",
                       interpolation="nearest",
                       norm=divnorm,
                       cmap=cmap,
                       extent=(min(n_list), max(n_list), min(d_list),
                               max(d_list)),
                       origin="lower")
        ax.set_xlabel("Number of signatures (N)")
        ax.set_ylabel("Dimension of matrix (D)")
        ax.set_title(name.split("_")[0])

    # Reserve the right edge of the figure for the shared colorbar.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.90, 0.05, 0.05, 0.9])
    fig.colorbar(cm.ScalarMappable(norm=divnorm, cmap=cmap), cax=cbar_ax)
コード例 #20
0
ファイル: plot_europe.py プロジェクト: maxnoe/soak19
    z,
    cmap='jet',
    vmin=-5e3,
    vmax=5e3,
    # extent=[-180, 180, -90, 90]
)
img.set_rasterized(True)

fig.colorbar(img)
fig.savefig('build/plots/europe_jet.pdf')

# Build a two-part terrain colormap: the lower 17 % of "terrain" for depths
# below sea level and the range from 25 % upwards for land.
terrain = plt.cm.terrain
colors_undersea = terrain(np.linspace(0, 0.17, 256))
colors_land = terrain(np.linspace(0.25, 1, 256))
all_colors = np.concatenate([colors_undersea, colors_land], axis=0)
terrain_map = LinearSegmentedColormap.from_list('terrain_map', all_colors)

# Pin 0 to the join between the sea and land halves of the colormap.
divnorm = DivergingNorm(vmin=-5e3, vcenter=0, vmax=5e3)

fig = plt.figure(constrained_layout=True)
ax = fig.add_subplot(111)
ax.set_axis_off()

# extent=[-180, 180, -90, 90] is deliberately left disabled here.
img = plt.imshow(z, cmap=terrain_map, norm=divnorm)
img.set_rasterized(True)

fig.colorbar(img)
fig.savefig('build/plots/europe_divnorm.pdf')
コード例 #21
0
def showMatrix(matrix, x_array=None, y_array=None, **kwargs):
    """Show a matrix using :meth:`~matplotlib.axes.Axes.imshow`. Curves on x- and y-axis can be added.

    When *x_array* and/or *y_array* are given, the figure is laid out on a
    grid with the matrix in the main cell, *x_array* drawn above it and
    *y_array* drawn to its left. Either of them may also be a
    :class:`Bio.Phylo.BaseTree.Tree`, in which case the tree is drawn
    instead of a curve.

    :arg matrix: matrix to be displayed
    :type matrix: :class:`~numpy.ndarray`

    :arg x_array: data to be plotted above the matrix
    :type x_array: :class:`~numpy.ndarray`

    :arg y_array: data to be plotted on the left side of the matrix
    :type y_array: :class:`~numpy.ndarray`

    :arg percentile: a percentile threshold to remove outliers, i.e. only showing data within *p*-th 
                     to *100-p*-th percentile
    :type percentile: float

    :arg vmin: lower colour limit; overrides the value derived from *percentile*
    :type vmin: float

    :arg vmax: upper colour limit; overrides the value derived from *percentile*
    :type vmax: float

    :arg vcenter: value mapped to the centre of the colormap. Only used when
                  *norm* is not given (and only under Python 3)
    :type vcenter: float

    :arg norm: normalization passed to ``imshow``. If *vmin*/*vmax* are also
               given they are stored on this object
    :type norm: :class:`~matplotlib.colors.Normalize`

    :arg linewidth: line width of the curves drawn for *x_array*/*y_array*,
                    default is 1
    :type linewidth: float

    :arg ratio: size of the matrix panel relative to the side panels,
                default is 6
    :type ratio: float

    :arg ticklabels: labels used for both axes unless *xticklabels* or
                     *yticklabels* is given
    :type ticklabels: list

    :arg xticklabels: labels for the x-axis of the matrix
    :type xticklabels: list

    :arg yticklabels: labels for the y-axis of the matrix; only applied when
                      no *y_array* panel is drawn
    :type yticklabels: list

    :arg interactive: turn on or off the interactive options
    :type interactive: bool

    :arg xtickrotation: how much to rotate the xticklabels in degrees
                        default is 0
    :type xtickrotation: float

    :arg colorbar: whether to draw a colorbar, default is **True**
    :type colorbar: bool

    :arg cb_extend: ``extend`` argument of the colorbar, default is
                    ``'neither'``
    :type cb_extend: str

    :arg allticks: put a major tick on every row and column
    :type allticks: bool

    :arg cmap: colormap, default is ``'jet'``
    :type cmap: str

    :arg origin: ``origin`` argument of ``imshow``, default is ``'lower'``
    :type origin: str

    :arg aspect: ``aspect`` argument of ``imshow``; read only when neither
                 *x_array* nor *y_array* is given, other layouts use
                 ``'auto'``

    Remaining keyword arguments are forwarded to ``imshow``.

    :returns: the image, the list of line collections drawn for the side
              curves, and the colorbar (**None** if not drawn)
    :raises ImportError: if Biopython cannot be imported
    """

    from matplotlib import ticker
    from matplotlib.gridspec import GridSpec
    from matplotlib.collections import LineCollection
    from matplotlib.pyplot import gca, sca, sci, colorbar, subplot

    from .drawtools import drawTree, IndexFormatter

    # Percentile-based limits are only defaults; explicit vmin/vmax win.
    p = kwargs.pop('percentile', None)
    vmin = vmax = None
    if p is not None:
        vmin = np.percentile(matrix, p)
        vmax = np.percentile(matrix, 100 - p)

    vmin = kwargs.pop('vmin', vmin)
    vmax = kwargs.pop('vmax', vmax)
    vcenter = kwargs.pop('vcenter', None)
    norm = kwargs.pop('norm', None)

    if vcenter is not None and norm is None:
        if PY3K:
            # The class was renamed from DivergingNorm to TwoSlopeNorm.
            try:
                from matplotlib.colors import DivergingNorm
            except ImportError:
                from matplotlib.colors import TwoSlopeNorm as DivergingNorm

            # Honour the requested centre (it used to be hard-coded to 0).
            norm = DivergingNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
        else:
            LOGGER.warn(
                'vcenter cannot be used in Python 2 so norm remains None')

    lw = kwargs.pop('linewidth', 1)

    W = H = kwargs.pop('ratio', 6)

    ticklabels = kwargs.pop('ticklabels', None)
    xticklabels = kwargs.pop('xticklabels', ticklabels)
    yticklabels = kwargs.pop('yticklabels', ticklabels)

    xtickrotation = kwargs.pop('xtickrotation', 0.)

    show_colorbar = kwargs.pop('colorbar', True)
    cb_extend = kwargs.pop('cb_extend', 'neither')
    allticks = kwargs.pop(
        'allticks', False
    )  # this argument is temporary and will be replaced by better implementation
    interactive = kwargs.pop('interactive', True)

    cmap = kwargs.pop('cmap', 'jet')
    origin = kwargs.pop('origin', 'lower')

    try:
        from Bio import Phylo
    except ImportError:
        raise ImportError('Phylo module could not be imported. '
                          'Reinstall ProDy or install Biopython '
                          'to solve the problem.')
    tree_mode_y = isinstance(y_array, Phylo.BaseTree.Tree)
    tree_mode_x = isinstance(x_array, Phylo.BaseTree.Tree)

    # Choose the grid: (i, j) is the cell of the matrix panel; the upper
    # panel sits at (i - 1, j) and the left panel at (i, j - 1).
    if x_array is not None and y_array is not None:
        nrow = 2
        ncol = 2
        i = 1
        j = 1
        width_ratios = [1, W]
        height_ratios = [1, H]
        aspect = 'auto'
    elif x_array is not None and y_array is None:
        nrow = 2
        ncol = 1
        i = 1
        j = 0
        width_ratios = [W]
        height_ratios = [1, H]
        aspect = 'auto'
    elif x_array is None and y_array is not None:
        nrow = 1
        ncol = 2
        i = 0
        j = 1
        width_ratios = [1, W]
        height_ratios = [H]
        aspect = 'auto'
    else:
        nrow = 1
        ncol = 1
        i = 0
        j = 0
        width_ratios = [W]
        height_ratios = [H]
        aspect = kwargs.pop('aspect', None)

    main_index = (i, j)
    upper_index = (i - 1, j)
    left_index = (i, j - 1)

    complex_layout = nrow > 1 or ncol > 1

    ax1 = ax2 = ax3 = None

    if complex_layout:
        gs = GridSpec(nrow,
                      ncol,
                      width_ratios=width_ratios,
                      height_ratios=height_ratios,
                      hspace=0.,
                      wspace=0.)

    ## draw matrix
    if complex_layout:
        ax3 = subplot(gs[main_index])
    else:
        ax3 = gca()

    # Recent matplotlib raises ValueError when vmin/vmax are passed together
    # with a Normalize instance, so store the limits on the norm instead
    # (older matplotlib did exactly this internally).
    im_vmin, im_vmax = vmin, vmax
    if norm is not None and hasattr(norm, 'vmin') and hasattr(norm, 'vmax'):
        if vmin is not None:
            norm.vmin = vmin
        if vmax is not None:
            norm.vmax = vmax
        im_vmin = im_vmax = None

    im = ax3.imshow(matrix,
                    aspect=aspect,
                    vmin=im_vmin,
                    vmax=im_vmax,
                    norm=norm,
                    cmap=cmap,
                    origin=origin,
                    **kwargs)

    #ax3.set_xlim([-0.5, matrix.shape[0]+0.5])
    #ax3.set_ylim([-0.5, matrix.shape[1]+0.5])

    if xticklabels is not None:
        ax3.xaxis.set_major_formatter(IndexFormatter(xticklabels))
    if yticklabels is not None and ncol == 1:
        ax3.yaxis.set_major_formatter(IndexFormatter(yticklabels))

    if allticks:
        ax3.xaxis.set_major_locator(ticker.IndexLocator(offset=0.5, base=1.))
        ax3.yaxis.set_major_locator(ticker.IndexLocator(offset=0.5, base=1.))
    else:
        # Each axis needs its own locator instances.
        locator = ticker.AutoLocator()
        locator.set_params(integer=True)
        minor_locator = ticker.AutoMinorLocator()

        ax3.xaxis.set_major_locator(locator)
        ax3.xaxis.set_minor_locator(minor_locator)

        locator = ticker.AutoLocator()
        locator.set_params(integer=True)
        minor_locator = ticker.AutoMinorLocator()

        ax3.yaxis.set_major_locator(locator)
        ax3.yaxis.set_minor_locator(minor_locator)

    if ncol > 1:
        # The left panel occupies the space where y tick labels would go.
        ax3.yaxis.set_major_formatter(ticker.NullFormatter())

    ## draw x_ and y_array
    lines = []

    if nrow > 1:
        ax1 = subplot(gs[upper_index])

        if tree_mode_x:
            Y, X = drawTree(x_array,
                            label_func=None,
                            orientation='vertical',
                            inverted=True)
            miny = min(Y.values())
            maxy = max(Y.values())

            minx = min(X.values())
            maxx = max(X.values())

            ax1.set_xlim(minx - .5, maxx + .5)
            ax1.set_ylim(miny, 1.05 * maxy)
        else:
            ax1.set_xticklabels([])

            y = x_array
            xp, yp = interpY(y)
            # Turn the curve into consecutive 2-point segments so that each
            # segment can be coloured by its value.
            points = np.array([xp, yp]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lcy = LineCollection(segments, array=yp, linewidths=lw, cmap=cmap)
            lines.append(lcy)
            ax1.add_collection(lcy)

            ax1.set_xlim(xp.min() - .5, xp.max() + .5)
            ax1.set_ylim(yp.min(), yp.max())

        # Keep the upper panel aligned with the matrix. This must act on
        # ax1: ax2 has not been created yet at this point.
        if ax3.xaxis_inverted():
            ax1.invert_xaxis()

        ax1.axis('off')

    if ncol > 1:
        ax2 = subplot(gs[left_index])

        if tree_mode_y:
            X, Y = drawTree(y_array, label_func=None, inverted=True)
            miny = min(Y.values())
            maxy = max(Y.values())

            minx = min(X.values())
            maxx = max(X.values())

            ax2.set_ylim(miny - .5, maxy + .5)
            ax2.set_xlim(minx, 1.05 * maxx)
        else:
            ax2.set_xticklabels([])

            y = y_array
            xp, yp = interpY(y)
            # Same segment construction as above, with x and y swapped so
            # the curve runs vertically.
            points = np.array([yp, xp]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lcx = LineCollection(segments, array=yp, linewidths=lw, cmap=cmap)
            lines.append(lcx)
            ax2.add_collection(lcx)
            ax2.set_xlim(yp.min(), yp.max())
            ax2.set_ylim(xp.min() - .5, xp.max() + .5)

        ax2.invert_xaxis()

        if ax3.yaxis_inverted():
            ax2.invert_yaxis()

        ax2.axis('off')

    ## draw colorbar
    sca(ax3)
    cb = None
    if show_colorbar:
        if nrow > 1:
            # Let the colorbar take space from every existing panel and
            # shrink it to the height share of the matrix panel.
            axes = [ax for ax in (ax1, ax2, ax3) if ax is not None]
            s = H / (H + 1.)
            cb = colorbar(mappable=im,
                          ax=axes,
                          anchor=(0, 0),
                          shrink=s,
                          extend=cb_extend)
        else:
            cb = colorbar(mappable=im, extend=cb_extend)

    # Leave the matrix axes and image as pyplot's current axes/image.
    sca(ax3)
    sci(im)

    if interactive:
        from prody.utilities import ImageCursor
        from matplotlib.pyplot import connect
        cursor = ImageCursor(ax3, im)
        connect('button_press_event', cursor.onClick)

    ax3.tick_params(axis='x', rotation=xtickrotation)

    return im, lines, cb
コード例 #22
0
ファイル: pca.py プロジェクト: indebetouw/magellanic-pms
def pca_clumps(clumps_df,ncomps=3,pca_stats=False,plot=True):
    """Run a PCA and optionally print statistics and draw diagnostic plots.

    NOTE(review): the PCA is fitted on the module-level ``df_scaled``, the
    correlation heatmap uses the module-level ``df`` and the biplot colours
    come from the module-level ``full_prob``; none of them is a parameter.
    ``clumps_df`` only supplies the feature names used as plot labels --
    presumably ``df_scaled`` is the scaled version of ``clumps_df``; confirm
    against the caller.

    Args:
        clumps_df: DataFrame whose column names label the features.
        ncomps: value passed to ``PCA(n_components=...)``.
        pca_stats: if true, print the head of the transformed data and the
            explained variance ratios.
        plot: if true, draw the correlation heatmap, component loadings,
            cumulative explained variance, scree plot and biplot.

    Returns:
        DataFrame holding the PCA-transformed data.
    """
    # Apply PCA
    pca = PCA(n_components = ncomps)
    pca_fit = pca.fit_transform(df_scaled)
    clus_principal = pd.DataFrame(pca_fit)

    if pca_stats:
        print(clus_principal.head())
        print(pca.explained_variance_ratio_)

    if not plot:
        return clus_principal

    feature_names = list(clumps_df.columns)
    # Use the fitted shape rather than assuming three components.
    n_found, n_features = pca.components_.shape
    comp_index = np.arange(1, n_found + 1)

    # check correlations (upper triangle masked out)
    cor = pd.DataFrame(df).corr()
    # plain bool: the np.bool alias no longer exists in current NumPy
    mask = np.triu(np.ones_like(cor,dtype=bool))
    plt.figure()
    sn.heatmap(cor,annot=True,mask=mask,vmin=-1,vmax=1)
    plt.xticks(rotation=15)
    plt.yticks(rotation=55)
    plt.title('Correlations')

    # effects of each dimension on factor plot
    plt.figure(figsize=(8,4))
    plt.imshow(pca.components_,interpolation='none',cmap='plasma',vmin=-1,vmax=1)
    ax = plt.gca()
    # Pin tick positions before labelling; otherwise the labels are attached
    # to whatever ticks the automatic locator happened to choose.
    ax.set_xticks(np.arange(n_features))
    ax.set_xticklabels(feature_names[:n_features], rotation=60, ha='left',
                       fontsize=10)
    ax.set_yticks(np.arange(n_found))
    ax.set_yticklabels(['PC{}'.format(k) for k in comp_index],
                       va='bottom', fontsize=12)
    plt.colorbar(orientation='horizontal', ticks=[pca.components_.min(), 0,
                                                  pca.components_.max()], pad=0.2)
    plt.title('Impact of Observables on PCs')

    # cumulative explained variance
    plt.figure()
    plt.plot(comp_index,np.cumsum(pca.explained_variance_ratio_))
    plt.xlabel('Number of Components')
    plt.ylabel('Cumulative Explained Variance')
    plt.title('Cumulative Explained Variance of PCs')
    plt.ylim(0,1)

    # scree plot
    plt.figure()
    plt.plot(comp_index,pca.explained_variance_)
    plt.xlabel('Number of Components')
    plt.ylabel('Eigenvalue')
    plt.title('Scree Plot of PCs')

    if n_found < 2:
        # A biplot needs two components; nothing more to draw.
        return clus_principal

    score = pca_fit[:,0:2]
    coeff = np.transpose(pca.components_[0:2, :])

    # biplot: scores rescaled to unit range, loadings drawn as arrows
    plt.figure()
    xs = score[:,0]
    ys = score[:,1]
    scalex = 1.0/(xs.max() - xs.min())
    scaley = 1.0/(ys.max() - ys.min())

    # Colour scale over [0, 1] centred on 0.5.
    norm = DivergingNorm(vmin=0, vcenter=0.5,vmax=1)
    plt.scatter(xs * scalex,ys * scaley,c=full_prob[:,1],
                cmap='RdYlBu_r',s=0.1,alpha=0.5,norm=norm)

    for i in range(coeff.shape[0]):
        plt.arrow(0, 0, coeff[i,0], coeff[i,1],color = 'r',alpha = 0.5)
        plt.text(coeff[i,0]* 1.15, coeff[i,1] * 1.15, feature_names[i], color = 'g', ha = 'center', va = 'center')

    plt.xlabel("PC{}".format(1))
    plt.ylabel("PC{}".format(2))
    plt.grid()
    plt.title('Biplot')

    return clus_principal