Example #1
0
 def draw_transect(self,outpath,fname,radar=True):
     """Draw the A-B great-circle transect over the 'cref' field and save it.

     The great circle runs from (self.lonA, self.latA) to
     (self.lonB, self.latB) and is drawn on top of filled contours of the
     'cref' field taken from ``self.W`` at time ``self.tidx``. Contour levels
     and colormap come from ``Scales('cref', False)``. The figure is written
     with ``BirdsEye.save(outpath, fname)`` and then closed.

     Args:
         outpath: output location passed to ``BirdsEye.save``.
         fname: output filename passed to ``BirdsEye.save``.
         radar (bool): accepted for interface compatibility; not used.
     """
     birdseye = BirdsEye(self.W)
     bmap, xpts, ypts = birdseye.basemap_setup()
     bmap.drawgreatcircle(self.lonA, self.latA, self.lonB, self.latB)

     # The field is requested with level=False.
     field_name, field_level = 'cref', False
     scale = Scales('cref', False)
     levels, colormap = scale.clvs, scale.cm
     data = self.W.get(field_name, utc=self.tidx, level=field_level)
     bmap.contourf(xpts, ypts, data[0, 0, :, :], levels=levels, cmap=colormap)

     birdseye.save(outpath, fname)
     plt.close(birdseye.fig)
Example #2
0
 def get_scales(scales=None, vrbl=None, lvs=None):
     """Return the ``(cmap, clvs)`` pair to use for a plot.

     Precedence: a truthy ``scales`` object wins; otherwise, if ``vrbl`` is
     truthy, ``Scales(vrbl)`` is constructed; with neither, both values are
     None. A non-None ``lvs`` always replaces the contour levels.

     Args:
         scales: object exposing ``.cm`` and ``.clvs``.
         vrbl (str): variable name used to build ``Scales(vrbl)`` when
             ``scales`` is not given.
         lvs: explicit contour levels overriding any looked-up levels.

     Returns:
         tuple: ``(cmap, clvs)``.
     """
     if scales:
         source = scales
     elif vrbl:
         source = Scales(vrbl)
     else:
         source = None

     cmap, clvs = (None, None) if source is None else (source.cm, source.clvs)
     if lvs is not None:
         clvs = lvs
     return cmap, clvs
Example #3
0
    def _get_plot_options1(self,vrbl=None,*args,**kwargs):
        """ Filter arguments and key-word arguments for plotting methods.

        Whatever is in dictionary will overwrite defaults in the plotting
        method.

        These may be
          * fhrs (forecast hour plotting times - or all)
          * ensmems (ensemble members, or all)


        """
        # Get plotting levels if not already given
        # TODO: rewrite this using hasattr() or something.
        if vrbl:
            S = Scales(vrbl)
            if not 'levels' in kwargs:
                kwargs['levels'] = S.clvs
            if not 'cmap' in kwargs:
                kwargs['cmap'] = S.cm

        # Specific things for certain variables
        if vrbl in ('REFL_10CM',"REFL_comp"):
            pass

        # Save all figures to a subdirectory
        if 'subdir' in kwargs:
            utils.trycreate(subdir,is_folder=True)

        # What times are being plotted?
        # If all, return list of all times
        if 'utc' in kwargs:
            pass
        elif ('fchr' not in kwargs) or (kwargs['fchr'] == 'all'):
            kwargs['utc'] = E.list_of_times
        # Does this pick up on numpy arange?
        elif isinstance(kwargs['fchr'], (list,tuple)):
            kwargs['utc'] = []
            for f in kwargs['fchr']:
                utc = self.inittime + datetime.timedelta(seconds=3600*f)
                kwargs['fchr'].append(utc)

        # Make domain smaller if requested

        # Save data before plotting
        clskwargs['save_data'] = kwargs.get('save_data',False)
        return clskwargs,plotkwargs,mplkwargs
Example #4
0
# Input data root, forecast initialisation time and output directory.
dataroot = '/scratch/john.lawson/WRF/VSE_reso/20160331'
initutc = datetime.datetime(2016,3,31,21,0,0)
outdir = '/home/john.lawson/VSE_reso/pyoutput/attempt1'
# Forecast hours: [0.25, 1, 2].
fhrs = [0.25,] + list(range(1,3))
# Panel identifiers: 'Verif', '' and the integers 1-10
# (presumably ensemble member numbers and a blank slot -- confirm).
plotlist = ['Verif',''] + list(range(1,11))

### OPTIONS ###
# Variables to draw as thumbnails, and score names
# (presumably verification metrics -- confirm).
thumbnails = ['REFL_comp',]
scores = ['CRPS',]

### INSTANCES etc ###
# Ensemble built from dataroot/initutc with doms=2, ctrl=False, loadobj=False.
E = Ensemble(dataroot,initutc,ctrl=False,loadobj=False,doms=2)
# NOTE(review): `st4dir` is not defined in the visible part of this script;
# confirm it is set before this line, otherwise this raises NameError.
ST4 = StageIV(st4dir)
# Per-domain limit dicts (presumably lat/lon bounds); plot_thumbnails unpacks
# them as keyword arguments to Radar.get_subdomain.
d01_limdict = E.get_limits(dom=1)
d02_limdict = E.get_limits(dom=2)
# Scales lookup for composite reflectivity.
S = Scales('REFL_comp')

### FUNCTIONS ###
def plot_thumbnails(plotutc,vrbl):
    """ Plot data, including verification if it exists.

    There are three plotting options.
    The d01 and d02 domains, as is, and a third
    interpolation to common (inner) domain.
    """
    if vrbl == 'REFL_comp':
        R_large = Radar(plotutc,radardir)
        R_small = copy.copy(R_large)
        R_large.get_subdomain(**d01_limdict,overwrite=True)
        R_small.get_subdomain(**d02_limdict,overwrite=True)
    for nloop,dom in enumerate((1,1,2)):
Example #5
0
        axit = iter(axes.flatten())
        for folder in sorted(FOLDERS):
            if "netcdf4" in folder:
                ncno = 4
            elif "nco" in folder:
                ncno = " ncrcat"
            else:
                ncno = 3
            for mem in sorted(MEMS):
                ax = next(axit)
                ss = folder.split('/')[-1]
                fname = 'wrfout_{}_2017-05-03_02:00:00'.format(dom)

                lvidx = 0
                if vrbl == 'REFL_comp':
                    S = Scales(vrbl=vrbl)
                    cmap = S.cm
                    lvs = S.clvs
                elif vrbl == 'WSPD10MAX':
                    cmap = 'gist_earth_r'
                    lvs = N.arange(2.5, 22.5, 2.5)
                elif vrbl == 'W':
                    lvidx = 20
                else:
                    raise Exception

                fpath = os.path.join(folder, mem, fname)
                # ax.set_title("Dom {} for the {} run".format(dom,ss),fontsize=6)
                ax.set_title("{} for netCDF{}".format(mem, ncno))

                if ("onedom" in ss) and (dom == 'd02'):
Example #6
0
    def plot(
            self,
            fpath,
            fmt='default',
            W=None,
            vrbl='REFL_comp',
            # Nlim=None,Elim=None,Slim=None,Wlim=None):
            ld=None,
            lats=None,
            lons=None,
            fig=None,
            ax=None):
        """ Plot basic quicklook images.

        Setting fmt to 'default' will plot raw data plus the objects
        identified, in a two-panel ``Figure`` written to ``fpath``. The left
        panel shows ``self.obj_array`` (values < 1 masked) with one discrete
        colour per object. The right panel shows ``self.raw_data`` with
        levels and colormap from ``Scales(vrbl)``. Any other ``fmt`` plots
        nothing.

        Args:
            fpath: output path handed to ``Figure``.
            fmt (str): only ``'default'`` produces a plot.
            W: passed through to ``BirdsEye.plot2D``.
            vrbl (str): variable name used to look up ``Scales`` for the
                raw-data panel.
            ld (dict, optional): extra keyword arguments for both
                ``plot2D`` calls.
            lats, lons: passed through to ``plot2D``.
            fig, ax: accepted for interface compatibility; ignored (a new
                ``Figure`` is always created).
        """
        import matplotlib as M

        if ld is None:
            ld = dict()
        nobjs = len(self.objects)

        if fmt != 'default':
            return

        F = Figure(ncols=2, nrows=1, figsize=(8, 4), fpath=fpath)
        with F:
            # Left panel: identified objects, one discrete colour each.
            BE = BirdsEye(ax=F.ax[0], fig=F.fig)

            try:
                cmap_og = M.colormaps['tab20']
            except AttributeError:  # matplotlib < 3.5 has no registry
                cmap_og = M.cm.get_cmap('tab20')
            if nobjs > 0:
                # A ListedColormap of nobjs samples matches
                # LinearSegmentedColormap.from_list(..., N=nobjs) for
                # nobjs >= 2, and also works for a single object (where
                # from_list raises ValueError).
                color_list = cmap_og(N.linspace(0, 1, nobjs))
                cmap = M.colors.ListedColormap(color_list,
                                               name='discrete_objects')
            else:
                # No objects: from_list would fail on an empty colour list.
                cmap = cmap_og
            masked_objs = N.ma.masked_less(self.obj_array, 1)

            BE.plot2D(
                plottype='pcolormesh',
                data=masked_objs,
                save=False,
                cb='horizontal',
                W=W,
                cmap=cmap,
                mplkwargs={'vmin': 1},
                **ld,
                lats=lats,
                lons=lons)

            # Right panel: raw data with the variable's standard scale.
            S = Scales(vrbl)
            BE = BirdsEye(ax=F.ax[1], fig=F.fig)
            BE.plot2D(data=self.raw_data,
                      save=False,
                      W=W,
                      cb='horizontal',
                      lats=lats,
                      lons=lons,
                      cmap=S.cm,
                      clvs=S.clvs,
                      **ld)
        return
Example #7
0
    def plot_quicklook(self, outdir, what='all', fname=None, ecc=0.2):
        """ Plot quick images of objects identified.

        Draws three panels in a row (or a 2x2 grid when ``what == "4-panel"``)
        and saves the figure to ``os.path.join(outdir, fname)``. Panel 1 is
        always the raw data. The other panels depend on ``what``:

        * ``"all"``: intensity array, object labels
        * ``"qlcs"``: objects with eccentricity >= ``ecc``, objects below it
        * ``"ecc"``: eccentricity labels, side-ratio labels
        * ``"extent"``: eccentricity labels, extent labels
        * ``"shapeindex"``: intensity array, shape-index labels
        * ``"ratio"``: intensity array, side-ratio labels
        * ``"pca"``: intensity array, PC1 (QLCS-ness) classification
        * ``"4-panel"``: eccentricity, extent, longest-side length

        Args:
            outdir (str): where to put the quickplot
            what (str): the type of quickplot (one of the values above)
            fname (str,optional): if None, "quicklook.png" is used.
            ecc (float): eccentricity threshold used when ``what == "qlcs"``.

        Raises:
            AssertionError: if ``what`` is not a recognised option.
        """
        assert what in ("all", "qlcs", "ecc", "shapeindex", "ratio", "extent",
                        "4-panel", "pca")

        def label_objs(bmap, ax, locs):
            # Annotate each (text, (lat, lon)) pair at its projected position.
            # A list (not a dict keyed by text) keeps every object, even when
            # two objects format to the same string.
            for text, (lat, lon) in locs:
                xpt, ypt = bmap(lon, lat)
                ax.annotate(text,
                            xy=(xpt, ypt),
                            xycoords="data",
                            zorder=1000,
                            fontsize=7,
                            fontweight='bold')

        def object_locs(make_text):
            # One (text, (centroid_lat, centroid_lon)) pair per object.
            return [(make_text(o), (o.centroid_lat, o.centroid_lon))
                    for o in self.objects.itertuples()]

        def shade_and_label(bmap, ax, field, locs, title, **mesh_kwargs):
            # Shade `field` where >= 1 (values < 1 masked), annotate `locs`,
            # and set the panel title. x/y are the current panel's grid.
            masked = N.ma.masked_where(field < 1, field)
            bmap.pcolormesh(data=masked, x=x, y=y, **mesh_kwargs)
            label_objs(bmap, ax, locs)
            ax.set_title(title)

        def do_label_array(bmap, ax):
            locs = object_locs(lambda o: str(int(o.label)))
            shade_and_label(bmap, ax, self.idxarray, locs, "Object labels")

        def do_label_shapeindex(bmap, ax):
            locs = []
            for o, okidx in zip(self.objects.itertuples(), self.OKidxs):
                raw_data = self.object_props[okidx].intensity_image
                # shape_index returns one value per pixel. NOTE(review): the
                # NaN-ignoring mean reduces it to one number per object --
                # confirm this is the intended summary.
                shidx = N.nanmean(skimage.feature.shape_index(raw_data))
                locs.append(("{:0.3f}".format(shidx),
                             (o.centroid_lat, o.centroid_lon)))
            shade_and_label(bmap, ax, self.idxarray, locs,
                            "Object shape index")

        def do_label_ecc(bmap, ax):
            locs = object_locs(lambda o: "{:0.2f}".format(o.eccentricity))
            shade_and_label(bmap, ax, self.OKidxarray, locs,
                            "Object eccentricity", alpha=0.5)

        def do_label_extent(bmap, ax):
            locs = object_locs(lambda o: "{:1.2f}".format(o.extent))
            shade_and_label(bmap, ax, self.OKidxarray, locs,
                            "Object extent (fill pc. of bbox)", alpha=0.5)

        def do_label_longest(bmap, ax):
            locs = object_locs(lambda o: "{}".format(int(o.longaxis_km)))
            shade_and_label(bmap, ax, self.OKidxarray, locs,
                            "Object longest-side length (km)", alpha=0.5)

        def do_label_ratio(bmap, ax):
            locs = object_locs(lambda o: "{:1.2f}".format(o.ratio))
            shade_and_label(bmap, ax, self.OKidxarray, locs,
                            "Object side-ratio (short/long)", alpha=0.5)

        def do_label_pca(bmap, ax, discrim_vals=(-0.2, 0.5)):
            import matplotlib as M

            # Class 1/2/3 = qlcsness below / between / above discrim_vals.
            locs = []
            obj_discrim = N.zeros_like(self.OKidxarray)
            for o in self.objects.itertuples():
                locs.append(("{:0.2f}".format(o.qlcsness),
                             (o.centroid_lat, o.centroid_lon)))
                if o.qlcsness < discrim_vals[0]:
                    discrim = 1
                elif o.qlcsness > discrim_vals[1]:
                    discrim = 3
                else:
                    discrim = 2
                obj_discrim = N.where(self.idxarray == int(o.label), discrim,
                                      obj_discrim)
            marr = N.ma.masked_where(obj_discrim < 1, obj_discrim)

            try:
                cmap = M.colormaps["magma"].resampled(3)
            except AttributeError:  # matplotlib < 3.6
                cmap = M.cm.get_cmap("magma", 3)
            pcm = bmap.pcolormesh(
                data=marr,
                x=x,
                y=y,
                alpha=0.5,
                vmin=1,
                vmax=3,
                cmap=cmap,
            )

            mode_names = ["", "Cellular", "Ambiguous", "Linear/Complex"]

            def mode_name(val, pos):
                # Tick values arrive as floats; map to the nearest class.
                idx = int(round(val))
                return mode_names[idx] if 0 <= idx < len(mode_names) else ""

            cax = fig.add_axes([0.75, 0.1, 0.2, 0.04])
            fig.colorbar(pcm,
                         cax=cax,
                         ticks=(1, 2, 3),
                         format=plt.FuncFormatter(mode_name),
                         orientation='horizontal')

            label_objs(bmap, ax, locs)
            ax.set_title("Object PC1 (QLCS-ness)")

        def do_raw_array(bmap, ax):
            bmap.contourf(data=self.raw_data,
                          x=x,
                          y=y,
                          levels=S.clvs,
                          cmap=S.cm)
            ax.set_title("Raw data")

        def do_intensity_array(bmap, ax):
            locs = [(str(int(o.label)),
                     (o.weighted_centroid_lat, o.weighted_centroid_lon))
                    for o in self.objects.itertuples()]
            bmap.contourf(data=self.object_field,
                          x=x,
                          y=y,
                          levels=S.clvs,
                          cmap=S.cm)
            label_objs(bmap, ax, locs)
            ax.set_title("Intensity array")

        def do_table(bmap, ax):
            # Draw the objects DataFrame as a table on this panel (not on
            # pyplot's "current" axes).
            table = self.objects
            ax.table(cellText=table.values.tolist(),
                     colLabels=table.columns,
                     loc='center')
            ax.axis('off')

        def do_highlow_ecc(bmap, ax, ecc, overunder):
            # Keep only objects with eccentricity >= ecc ("over") or < ecc
            # ("under"); everything else is zeroed in the plotted field.
            assert overunder in ("over", "under")
            func = N.greater_equal if overunder == "over" else N.less
            locs = []
            obj_field = N.zeros_like(self.object_field)
            for o in self.objects.itertuples():
                if func(o.eccentricity, ecc):
                    locs.append((str(int(o.label)),
                                 (o.centroid_lat, o.centroid_lon)))
                    obj_field = N.where(self.idxarray == o.label,
                                        self.object_field, obj_field)
            bmap.contourf(data=obj_field, x=x, y=y, levels=S.clvs, cmap=S.cm)
            label_objs(bmap, ax, locs)
            if overunder == "over":
                ax.set_title("QLCS objects (ecc > {:0.2f})".format(ecc))
            else:
                ax.set_title("Cellular objects (ecc < {:0.2f})".format(ecc))

        if what == "4-panel":
            fig, axes = plt.subplots(2, 2, figsize=(9, 7))
        else:
            fig, axes = plt.subplots(1, 3, figsize=(9, 4))

        if fname is None:
            fname = "quicklook.png"

        # Contour levels/colormap shared by the reflectivity-style panels;
        # identical for every panel, so built once.
        S = Scales(vrbl='REFL_comp')

        try:
            for n, ax in enumerate(axes.flat):
                print("Plotting subplot #", n + 1)
                bmap = self.create_bmap(ax=ax)
                # Projected grid coordinates, read by the helpers above.
                x, y = bmap(self.lons, self.lats)

                if n == 0:
                    do_raw_array(bmap, ax)

                elif n == 1:
                    if what == "qlcs":
                        do_highlow_ecc(bmap, ax, ecc, "over")
                    elif what in ("ecc", "extent", "4-panel"):
                        do_label_ecc(bmap, ax)
                    else:
                        do_intensity_array(bmap, ax)

                elif n == 2:
                    if what == "qlcs":
                        do_highlow_ecc(bmap, ax, ecc, "under")
                    elif what == "all":
                        do_label_array(bmap, ax)
                    elif what in ("ecc", "ratio"):
                        do_label_ratio(bmap, ax)
                    elif what == "shapeindex":
                        do_label_shapeindex(bmap, ax)
                    elif what in ("4-panel", "extent"):
                        do_label_extent(bmap, ax)
                    elif what == "pca":
                        do_label_pca(bmap, ax)

                elif n == 3:
                    # Only reachable with the 2x2 "4-panel" layout.
                    if what == "4-panel":
                        do_label_longest(bmap, ax)
                    else:
                        do_table(bmap, ax)

            fpath = os.path.join(outdir, fname)
            utils.trycreate(fpath)
            fig.tight_layout()
            fig.savefig(fpath)
            print("Saved figure to", fpath)
        finally:
            # Release the figure even if a panel fails, so figures don't leak.
            plt.close(fig)
        return